Update app.py
app.py
CHANGED
@@ -5,7 +5,7 @@ from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModelForSeq2SeqLM
 import logging
 import re
 
-app = FastAPI(
+app = FastAPI()
 
 # Enable CORS if needed
 from fastapi.middleware.cors import CORSMiddleware
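The hunk above shows the CORS import ("# Enable CORS if needed"); the middleware registration itself sits outside the diff. If it follows the stock FastAPI pattern, it would look roughly like the sketch below; the permissive origins are an assumption, not taken from app.py.

# Hypothetical CORS registration (not shown in this diff);
# allow_origins=["*"] is a placeholder, not the app's actual configuration.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)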
@@ -56,40 +56,37 @@ class GenerateTextRequest(BaseModel):
 
 @app.post("/generate-text")
 def generate_text(request: GenerateTextRequest):
-
-    المستوى = request.المستوى
-
-    if not المادة or not المستوى:
+    if not request.المادة or not request.المستوى:
         raise HTTPException(status_code=400, detail="المادة والمستوى مطلوبان.")
 
     try:
-        prompt = general_prompt_template.format(
-        inputs = text_tokenizer(prompt, return_tensors="pt", max_length=1024, truncation=True)
+        prompt = general_prompt_template.format(المادة=request.المادة, المستوى=request.المستوى)
+        inputs = text_tokenizer(prompt, return_tensors="pt", max_length=1024, truncation=True).to(device)
+
         with torch.no_grad():
             outputs = text_model.generate(
                 inputs.input_ids,
                 max_length=300,
                 num_return_sequences=1,
-                temperature=0.
-                top_p=0.
+                temperature=0.7,
+                top_p=0.95,
                 do_sample=True,
             )
-
-
-        generated_text = generated_text.replace(prompt, "").strip()
+
+        generated_text = text_tokenizer.decode(outputs[0], skip_special_tokens=True).replace(prompt, "").strip()
         logger.info(f"Generated text: {generated_text}")
         return {"generated_text": generated_text}
+
     except Exception as e:
         logger.error(f"Error during text generation: {str(e)}")
         raise HTTPException(status_code=500, detail=f"Error during text generation: {str(e)}")
 
 ####################################
-# Question & Answer Generation
+# Question & Answer Generation Model
 ####################################
-
 QA_MODEL_NAME = "Mihakram/AraT5-base-question-generation"
 qa_tokenizer = AutoTokenizer.from_pretrained(QA_MODEL_NAME)
-qa_model = AutoModelForSeq2SeqLM.from_pretrained(QA_MODEL_NAME)
+qa_model = AutoModelForSeq2SeqLM.from_pretrained(QA_MODEL_NAME).to(device)
 
 def extract_answer(context: str) -> str:
     """Extract the first sentence (or a key phrase) from the context."""
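The new `.to(device)` calls in this hunk rely on a `device` defined earlier in app.py, outside the changed lines. A minimal sketch of the conventional definition, assuming the usual CUDA-with-CPU-fallback idiom:

import torch

# Assumed definition; the variable name comes from the hunk above, but the
# definition itself is the standard idiom and is not visible in this diff.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")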
@@ -99,45 +96,49 @@ def extract_answer(context: str) -> str:
 
 def get_question(context: str, answer: str) -> str:
     """Generate a question based on the context and the candidate answer."""
-    text = "النص:
-    text_encoding = qa_tokenizer.encode_plus(text, return_tensors="pt")
-
-
-
-
-
-
-
-
-
-
-    question =
+    text = f"النص: {context} الإجابة: {answer} </s>"
+    text_encoding = qa_tokenizer.encode_plus(text, return_tensors="pt").to(device)
+
+    with torch.no_grad():
+        generated_ids = qa_model.generate(
+            input_ids=text_encoding['input_ids'],
+            attention_mask=text_encoding['attention_mask'],
+            max_length=64,
+            num_beams=5,
+            num_return_sequences=1
+        )
+
+    question = qa_tokenizer.decode(generated_ids[0], skip_special_tokens=True).replace("question:", "").strip()
     return question
 
-def generate_question_answer(context: str):
-    answer = extract_answer(context)
-    question = get_question(context, answer)
-    return question, answer
-
 class GenerateQARequest(BaseModel):
     text: str
 
 @app.post("/generate-qa")
 def generate_qa(request: GenerateQARequest):
-
-    if not context:
+    if not request.text:
         raise HTTPException(status_code=400, detail="Text is required.")
+
     try:
-        question, answer =
+        answer = extract_answer(request.text)
+        question = get_question(request.text, answer)
         logger.info(f"Generated QA -> Question: {question}, Answer: {answer}")
         return {"question": question, "answer": answer}
+
     except Exception as e:
         logger.error(f"Error during QA generation: {str(e)}")
         raise HTTPException(status_code=500, detail=f"Error during QA generation: {str(e)}")
 
+####################################
+# Root Endpoint
+####################################
 @app.get("/")
 def read_root():
     return {"message": "Welcome to the Arabic Text Generation API!"}
+
+####################################
+# Running the FastAPI Server
+####################################
 if __name__ == "__main__":
     import uvicorn
     uvicorn.run(app, host="0.0.0.0", port=7860)
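Only the docstring of extract_answer is visible in the diff; its body falls between the hunks. Given that docstring and the `re` import at the top of the file, a plausible first-sentence extractor might read as follows; this is a guess at the implementation, not the file's actual code.

import re

def extract_answer(context: str) -> str:
    """Extract the first sentence (or a key phrase) from the context."""
    # Split on sentence-ending punctuation, including the Arabic question
    # mark, and return the first non-empty piece.
    sentences = re.split(r"(?<=[.!?؟])\s+", context.strip())
    return sentences[0] if sentences else context.strip()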
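With these changes the two POST endpoints can be exercised end to end. A small client sketch using the `requests` library; the field names المادة ("subject") and المستوى ("level") come from the diff, while the example values and base URL are made up:

import requests

BASE = "http://localhost:7860"  # uvicorn binds 0.0.0.0:7860 in __main__

# /generate-text validates that both Arabic fields are present and returns
# HTTP 400 ("المادة والمستوى مطلوبان." -- "subject and level are required")
# if either is missing.
r = requests.post(f"{BASE}/generate-text",
                  json={"المادة": "العلوم", "المستوى": "الصف الخامس"})
print(r.json()["generated_text"])

# /generate-qa takes a plain "text" field and returns a question/answer pair.
r = requests.post(f"{BASE}/generate-qa",
                  json={"text": "الماء يتكون من الهيدروجين والأكسجين."})
print(r.json())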