yakine commited on
Commit
636399f
·
verified ·
1 Parent(s): 68394e4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +79 -53
app.py CHANGED
@@ -1,93 +1,72 @@
1
  from fastapi import FastAPI, HTTPException
2
  from pydantic import BaseModel
3
- from transformers import AutoTokenizer, AutoModelForCausalLM
4
  import torch
5
- from fastapi.middleware.cors import CORSMiddleware
6
  import logging
7
- from huggingface_hub import HfFolder
8
-
9
- # Configure logging
10
- logging.basicConfig(level=logging.INFO)
11
- logger = logging.getLogger(__name__)
12
-
13
-
14
 
15
  app = FastAPI()
16
 
17
- # Enable CORS
 
18
  app.add_middleware(
19
  CORSMiddleware,
20
- allow_origins=["*"], # Allow all origins (replace with your frontend URL in production)
21
  allow_credentials=True,
22
  allow_methods=["*"],
23
  allow_headers=["*"],
24
  )
25
 
26
- # Load your fine-tuned model and tokenizer
27
- MODEL_NAME = "aubmindlab/aragpt2-medium"
28
 
29
- try:
30
- tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
31
- model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)
32
- except Exception as e:
33
- logger.error(f"Failed to load model or tokenizer: {str(e)}")
34
- raise RuntimeError(f"Failed to load model or tokenizer: {str(e)}")
 
35
 
36
- # Define the general prompt template
37
  general_prompt_template = """
38
  أنت الآن نموذج لغة مخصص لتوليد نصوص عربية تعليمية بناءً على المادة والمستوى التعليمي. سيتم إعطاؤك مادة تعليمية ومستوى تعليمي، وعليك إنشاء نص مناسب بناءً على ذلك. النص يجب أن يكون:
39
  1. واضحًا وسهل الفهم.
40
  2. مناسبًا للمستوى التعليمي المحدد.
41
  3. مرتبطًا بالمادة التعليمية المطلوبة.
42
  4. قصيرًا ومباشرًا.
43
-
44
  ### أمثلة:
45
  1. المادة: العلوم
46
  المستوى: الابتدائي
47
  النص: النباتات كائنات حية تحتاج إلى الماء والهواء وضوء الشمس لتنمو. بعض النباتات تنتج أزهارًا وفواكه. النباتات تساعدنا في الحصول على الأكسجين.
48
-
49
  2. المادة: التاريخ
50
  المستوى: المتوسط
51
  النص: التاريخ هو دراسة الماضي وأحداثه المهمة. من خلال التاريخ، نتعلم عن الحضارات القديمة مثل الحضارة الفرعونية والحضارة الإسلامية. التاريخ يساعدنا على فهم تطور البشرية.
52
-
53
  3. المادة: الجغرافيا
54
  المستوى: المتوسط
55
  النص: الجغرافيا هي دراسة الأرض وخصائصها. نتعلم عن القارات والمحيطات والجبال. الجغرافيا تساعدنا على فهم العالم من حولنا.
56
-
57
  ---
58
-
59
  المادة: {المادة}
60
  المستوى: {المستوى}
61
-
62
  اكتب نصًا مناسبًا بناءً على المادة والمستوى المحددين. ركّز على جعل النص بسيطًا وواضحًا للمستوى المطلوب.
63
  """
64
 
65
- class GenerateRequest(BaseModel):
66
  المادة: str
67
  المستوى: str
68
 
69
- @app.post("/generate")
70
- def generate_text(request: GenerateRequest):
71
  المادة = request.المادة
72
  المستوى = request.المستوى
73
 
74
- logger.info(f"Received request: المادة={المادة}, المستوى={المستوى}")
75
-
76
- if not المادة or not المستوى or not isinstance(المادة, str) or not isinstance(المستوى, str):
77
- logger.error("المادة والمستوى مطلوبان ويجب أن يكونا نصًا.")
78
- raise HTTPException(status_code=400, detail="المادة والمستوى مطلوبان ويجب أن يكونا نصًا.")
79
 
80
  try:
81
- # Format the prompt with user inputs
82
- arabic_prompt = general_prompt_template.format(المادة=المادة, المستوى=المستوى)
83
- logger.info(f"Formatted prompt: {arabic_prompt}")
84
-
85
- # Tokenize the prompt
86
- inputs = tokenizer(arabic_prompt, return_tensors="pt", max_length=1024, truncation=True)
87
-
88
- # Generate text
89
  with torch.no_grad():
90
- outputs = model.generate(
91
  inputs.input_ids,
92
  max_length=300,
93
  num_return_sequences=1,
@@ -95,20 +74,67 @@ def generate_text(request: GenerateRequest):
95
  top_p=0.9,
96
  do_sample=True,
97
  )
98
-
99
- # Decode the generated text
100
- generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
101
- logger.info(f"Generated text: {generated_text}")
102
-
103
  # Remove the prompt from the generated text
104
- generated_text = generated_text.replace(arabic_prompt, "").strip()
105
-
106
  return {"generated_text": generated_text}
107
-
108
  except Exception as e:
109
  logger.error(f"Error during text generation: {str(e)}")
110
  raise HTTPException(status_code=500, detail=f"Error during text generation: {str(e)}")
111
 
112
- @app.post("/")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
113
  def read_root():
114
- return {"message": "Welcome to the Arabic Text Generation API!"}
 
1
  from fastapi import FastAPI, HTTPException
2
  from pydantic import BaseModel
 
3
  import torch
4
+ from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModelForSeq2SeqLM
5
  import logging
6
+ import re
 
 
 
 
 
 
7
 
8
  app = FastAPI()
9
 
10
+ # Enable CORS if needed
11
+ from fastapi.middleware.cors import CORSMiddleware
12
  app.add_middleware(
13
  CORSMiddleware,
14
+ allow_origins=["*"], # In production, restrict this to your frontend URL
15
  allow_credentials=True,
16
  allow_methods=["*"],
17
  allow_headers=["*"],
18
  )
19
 
20
+ logger = logging.getLogger(__name__)
21
+ logging.basicConfig(level=logging.INFO)
22
 
23
+ ####################################
24
+ # Text Generation Endpoint
25
+ ####################################
26
+
27
+ TEXT_MODEL_NAME = "aubmindlab/aragpt2-medium"
28
+ text_tokenizer = AutoTokenizer.from_pretrained(TEXT_MODEL_NAME)
29
+ text_model = AutoModelForCausalLM.from_pretrained(TEXT_MODEL_NAME)
30
 
 
31
  general_prompt_template = """
32
  أنت الآن نموذج لغة مخصص لتوليد نصوص عربية تعليمية بناءً على المادة والمستوى التعليمي. سيتم إعطاؤك مادة تعليمية ومستوى تعليمي، وعليك إنشاء نص مناسب بناءً على ذلك. النص يجب أن يكون:
33
  1. واضحًا وسهل الفهم.
34
  2. مناسبًا للمستوى التعليمي المحدد.
35
  3. مرتبطًا بالمادة التعليمية المطلوبة.
36
  4. قصيرًا ومباشرًا.
 
37
  ### أمثلة:
38
  1. المادة: العلوم
39
  المستوى: الابتدائي
40
  النص: النباتات كائنات حية تحتاج إلى الماء والهواء وضوء الشمس لتنمو. بعض النباتات تنتج أزهارًا وفواكه. النباتات تساعدنا في الحصول على الأكسجين.
 
41
  2. المادة: التاريخ
42
  المستوى: المتوسط
43
  النص: التاريخ هو دراسة الماضي وأحداثه المهمة. من خلال التاريخ، نتعلم عن الحضارات القديمة مثل الحضارة الفرعونية والحضارة الإسلامية. التاريخ يساعدنا على فهم تطور البشرية.
 
44
  3. المادة: الجغرافيا
45
  المستوى: المتوسط
46
  النص: الجغرافيا هي دراسة الأرض وخصائصها. نتعلم عن القارات والمحيطات والجبال. الجغرافيا تساعدنا على فهم العالم من حولنا.
 
47
  ---
 
48
  المادة: {المادة}
49
  المستوى: {المستوى}
 
50
  اكتب نصًا مناسبًا بناءً على المادة والمستوى المحددين. ركّز على جعل النص بسيطًا وواضحًا للمستوى المطلوب.
51
  """
52
 
53
+ class GenerateTextRequest(BaseModel):
54
  المادة: str
55
  المستوى: str
56
 
57
+ @app.post("/generate-text")
58
+ def generate_text(request: GenerateTextRequest):
59
  المادة = request.المادة
60
  المستوى = request.المستوى
61
 
62
+ if not المادة or not المستوى:
63
+ raise HTTPException(status_code=400, detail="المادة والمستوى مطلوبان.")
 
 
 
64
 
65
  try:
66
+ prompt = general_prompt_template.format(المادة=المادة, المستوى=المستوى)
67
+ inputs = text_tokenizer(prompt, return_tensors="pt", max_length=1024, truncation=True)
 
 
 
 
 
 
68
  with torch.no_grad():
69
+ outputs = text_model.generate(
70
  inputs.input_ids,
71
  max_length=300,
72
  num_return_sequences=1,
 
74
  top_p=0.9,
75
  do_sample=True,
76
  )
77
+ generated_text = text_tokenizer.decode(outputs[0], skip_special_tokens=True)
 
 
 
 
78
  # Remove the prompt from the generated text
79
+ generated_text = generated_text.replace(prompt, "").strip()
80
+ logger.info(f"Generated text: {generated_text}")
81
  return {"generated_text": generated_text}
 
82
  except Exception as e:
83
  logger.error(f"Error during text generation: {str(e)}")
84
  raise HTTPException(status_code=500, detail=f"Error during text generation: {str(e)}")
85
 
86
+ ####################################
87
+ # Question & Answer Generation Endpoint
88
+ ####################################
89
+
90
+ QA_MODEL_NAME = "Mihakram/AraT5-base-question-generation"
91
+ qa_tokenizer = AutoTokenizer.from_pretrained(QA_MODEL_NAME)
92
+ qa_model = AutoModelForSeq2SeqLM.from_pretrained(QA_MODEL_NAME)
93
+
94
+ def extract_answer(context: str) -> str:
95
+ """Extract the first sentence (or a key phrase) from the context."""
96
+ sentences = re.split(r'[.!؟]', context)
97
+ sentences = [s.strip() for s in sentences if s.strip()]
98
+ return sentences[0] if sentences else context
99
+
100
+ def get_question(context: str, answer: str) -> str:
101
+ """Generate a question based on the context and the candidate answer."""
102
+ text = "النص: " + context + " " + "الإجابة: " + answer + " </s>"
103
+ text_encoding = qa_tokenizer.encode_plus(text, return_tensors="pt")
104
+ qa_model.eval()
105
+ generated_ids = qa_model.generate(
106
+ input_ids=text_encoding['input_ids'],
107
+ attention_mask=text_encoding['attention_mask'],
108
+ max_length=64,
109
+ num_beams=5,
110
+ num_return_sequences=1
111
+ )
112
+ question = qa_tokenizer.decode(generated_ids[0], skip_special_tokens=True, clean_up_tokenization_spaces=True)
113
+ # Optionally remove a leading phrase if present
114
+ question = question.replace('question: ', '').strip()
115
+ return question
116
+
117
+ def generate_question_answer(context: str):
118
+ answer = extract_answer(context)
119
+ question = get_question(context, answer)
120
+ return question, answer
121
+
122
+ class GenerateQARequest(BaseModel):
123
+ text: str
124
+
125
+ @app.post("/generate-qa")
126
+ def generate_qa(request: GenerateQARequest):
127
+ context = request.text
128
+ if not context:
129
+ raise HTTPException(status_code=400, detail="Text is required.")
130
+ try:
131
+ question, answer = generate_question_answer(context)
132
+ logger.info(f"Generated QA -> Question: {question}, Answer: {answer}")
133
+ return {"question": question, "answer": answer}
134
+ except Exception as e:
135
+ logger.error(f"Error during QA generation: {str(e)}")
136
+ raise HTTPException(status_code=500, detail=f"Error during QA generation: {str(e)}")
137
+
138
+ @app.get("/")
139
  def read_root():
140
+ return {"message": "Welcome to the Arabic Text Generation API!"}