yakine committed on
Commit
8cdfbf3
·
verified ·
1 Parent(s): 84beaee

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +36 -36
app.py CHANGED
@@ -5,7 +5,7 @@ from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModelForSeq2Se
5
  import logging
6
  import re
7
 
8
- app = FastAPI(root_path="/")
9
 
10
  # Enable CORS if needed
11
  from fastapi.middleware.cors import CORSMiddleware
@@ -56,40 +56,37 @@ class GenerateTextRequest(BaseModel):
56
 
57
  @app.post("/generate-text")
58
  def generate_text(request: GenerateTextRequest):
59
- المادة = request.المادة
60
- المستوى = request.المستوى
61
-
62
- if not المادة or not المستوى:
63
  raise HTTPException(status_code=400, detail="المادة والمستوى مطلوبان.")
64
 
65
  try:
66
- prompt = general_prompt_template.format(المادة=المادة, المستوى=المستوى)
67
- inputs = text_tokenizer(prompt, return_tensors="pt", max_length=1024, truncation=True)
 
68
  with torch.no_grad():
69
  outputs = text_model.generate(
70
  inputs.input_ids,
71
  max_length=300,
72
  num_return_sequences=1,
73
- temperature=0.1,
74
- top_p=0.9,
75
  do_sample=True,
76
  )
77
- generated_text = text_tokenizer.decode(outputs[0], skip_special_tokens=True)
78
- # Remove the prompt from the generated text
79
- generated_text = generated_text.replace(prompt, "").strip()
80
  logger.info(f"Generated text: {generated_text}")
81
  return {"generated_text": generated_text}
 
82
  except Exception as e:
83
  logger.error(f"Error during text generation: {str(e)}")
84
  raise HTTPException(status_code=500, detail=f"Error during text generation: {str(e)}")
85
 
86
  ####################################
87
- # Question & Answer Generation Endpoint
88
  ####################################
89
-
90
  QA_MODEL_NAME = "Mihakram/AraT5-base-question-generation"
91
  qa_tokenizer = AutoTokenizer.from_pretrained(QA_MODEL_NAME)
92
- qa_model = AutoModelForSeq2SeqLM.from_pretrained(QA_MODEL_NAME)
93
 
94
  def extract_answer(context: str) -> str:
95
  """Extract the first sentence (or a key phrase) from the context."""
@@ -99,45 +96,48 @@ def extract_answer(context: str) -> str:
99
 
100
  def get_question(context: str, answer: str) -> str:
101
  """Generate a question based on the context and the candidate answer."""
102
- text = "النص: " + context + " " + "الإجابة: " + answer + " </s>"
103
- text_encoding = qa_tokenizer.encode_plus(text, return_tensors="pt")
104
- qa_model.eval()
105
- generated_ids = qa_model.generate(
106
- input_ids=text_encoding['input_ids'],
107
- attention_mask=text_encoding['attention_mask'],
108
- max_length=64,
109
- num_beams=5,
110
- num_return_sequences=1
111
- )
112
- question = qa_tokenizer.decode(generated_ids[0], skip_special_tokens=True, clean_up_tokenization_spaces=True)
113
- # Optionally remove a leading phrase if present
114
- question = question.replace('question: ', '').strip()
115
  return question
116
 
117
- def generate_question_answer(context: str):
118
- answer = extract_answer(context)
119
- question = get_question(context, answer)
120
- return question, answer
121
-
122
  class GenerateQARequest(BaseModel):
123
  text: str
124
 
125
  @app.post("/generate-qa")
126
  def generate_qa(request: GenerateQARequest):
127
- context = request.text
128
- if not context:
129
  raise HTTPException(status_code=400, detail="Text is required.")
 
130
  try:
131
- question, answer = generate_question_answer(context)
132
  logger.info(f"Generated QA -> Question: {question}, Answer: {answer}")
133
  return {"question": question, "answer": answer}
 
134
  except Exception as e:
135
  logger.error(f"Error during QA generation: {str(e)}")
136
  raise HTTPException(status_code=500, detail=f"Error during QA generation: {str(e)}")
137
 
 
 
 
138
  @app.get("/")
139
  def read_root():
140
  return {"message": "Welcome to the Arabic Text Generation API!"}
 
 
 
 
141
  if __name__ == "__main__":
142
  import uvicorn
143
  uvicorn.run(app, host="0.0.0.0", port=7860)
 
5
  import logging
6
  import re
7
 
8
+ app = FastAPI()
9
 
10
  # Enable CORS if needed
11
  from fastapi.middleware.cors import CORSMiddleware
 
56
 
57
@app.post("/generate-text")
def generate_text(request: GenerateTextRequest):
    """Generate Arabic educational text for the requested subject and level.

    Expects a GenerateTextRequest with non-empty ``المادة`` (subject) and
    ``المستوى`` (level) fields and returns ``{"generated_text": ...}``.

    Raises:
        HTTPException: 400 when either field is missing/empty,
            500 on any tokenizer/model failure.
    """
    if not request.المادة or not request.المستوى:
        raise HTTPException(status_code=400, detail="المادة والمستوى مطلوبان.")

    try:
        prompt = general_prompt_template.format(المادة=request.المادة, المستوى=request.المستوى)
        inputs = text_tokenizer(prompt, return_tensors="pt", max_length=1024, truncation=True).to(device)

        with torch.no_grad():
            outputs = text_model.generate(
                inputs.input_ids,
                # Pass the mask explicitly instead of letting generate() infer it.
                attention_mask=inputs.attention_mask,
                # FIX: max_length=300 counted the prompt tokens too (prompt may be
                # up to 1024 tokens after truncation), which could leave no room
                # for the completion. max_new_tokens bounds only the completion.
                max_new_tokens=300,
                num_return_sequences=1,
                temperature=0.7,
                top_p=0.95,
                do_sample=True,
            )

        # Causal-LM decoding echoes the prompt; strip it to return only the completion.
        generated_text = text_tokenizer.decode(outputs[0], skip_special_tokens=True).replace(prompt, "").strip()

        logger.info(f"Generated text: {generated_text}")
        return {"generated_text": generated_text}

    except Exception as e:
        logger.error(f"Error during text generation: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Error during text generation: {str(e)}")
83
 
84
  ####################################
85
+ # Question & Answer Generation Model
86
  ####################################
 
87
# AraT5 question-generation model (seq2seq) and its tokenizer, loaded once at
# import time. `device` is defined earlier in the file — presumably CPU or CUDA;
# TODO(review): confirm which device this Space actually provides.
QA_MODEL_NAME = "Mihakram/AraT5-base-question-generation"
qa_tokenizer = AutoTokenizer.from_pretrained(QA_MODEL_NAME)
qa_model = AutoModelForSeq2SeqLM.from_pretrained(QA_MODEL_NAME).to(device)
90
 
91
  def extract_answer(context: str) -> str:
92
  """Extract the first sentence (or a key phrase) from the context."""
 
96
 
97
def get_question(context: str, answer: str) -> str:
    """Generate a question about `context` whose answer is `answer`.

    Builds the AraT5 prompt, beam-searches one candidate without tracking
    gradients, and strips the model's leading "question:" tag before
    returning the plain question text.
    """
    prompt = "النص: " + context + " الإجابة: " + answer + " </s>"
    encoded = qa_tokenizer.encode_plus(prompt, return_tensors="pt").to(device)

    with torch.no_grad():
        output_ids = qa_model.generate(
            input_ids=encoded["input_ids"],
            attention_mask=encoded["attention_mask"],
            max_length=64,
            num_beams=5,
            num_return_sequences=1,
        )

    decoded = qa_tokenizer.decode(output_ids[0], skip_special_tokens=True)
    return decoded.replace("question:", "").strip()
113
 
 
 
 
 
 
114
class GenerateQARequest(BaseModel):
    """Request body for /generate-qa: `text` is the source passage to build a QA pair from."""
    text: str
116
 
117
@app.post("/generate-qa")
def generate_qa(request: GenerateQARequest):
    """Generate a (question, answer) pair from the submitted text.

    The answer is extracted from the text by `extract_answer`; the question
    is then generated by the AraT5 model conditioned on that answer.

    Raises:
        HTTPException: 400 when `text` is empty, 500 on any generation failure.
    """
    if not request.text:
        raise HTTPException(status_code=400, detail="Text is required.")

    try:
        # BUG FIX: get_question() returns a single string, so the previous
        # `question, answer = get_question(...)` tried to unpack that string
        # into two names (ValueError, surfacing as a 500, unless the question
        # happened to be exactly two characters). Compute the two values
        # separately instead.
        answer = extract_answer(request.text)
        question = get_question(request.text, answer)

        logger.info(f"Generated QA -> Question: {question}, Answer: {answer}")
        return {"question": question, "answer": answer}

    except Exception as e:
        logger.error(f"Error during QA generation: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Error during QA generation: {str(e)}")
130
 
131
+ ####################################
132
+ # Root Endpoint
133
+ ####################################
134
@app.get("/")
def read_root():
    """Landing endpoint confirming the API is up."""
    welcome = {"message": "Welcome to the Arabic Text Generation API!"}
    return welcome
137
+
138
+ ####################################
139
+ # Running the FastAPI Server
140
+ ####################################
141
if __name__ == "__main__":
    # Local import: uvicorn is only needed when running this file as a script.
    import uvicorn
    # Bind on all interfaces; 7860 is the conventional Hugging Face Spaces port.
    uvicorn.run(app, host="0.0.0.0", port=7860)