File size: 6,190 Bytes
5f2fd70 636399f 8df3bfc 636399f 6805517 8cdfbf3 5f2fd70 636399f 448c445 636399f 448c445 685129e 8df3bfc 448c445 636399f 25b7ae1 636399f 5f2fd70 636399f 5f2fd70 636399f 8cdfbf3 636399f 25b7ae1 8cdfbf3 b1b5846 8cdfbf3 25b7ae1 636399f 25b7ae1 8cdfbf3 25b7ae1 8cdfbf3 636399f 25b7ae1 8cdfbf3 25b7ae1 8df3bfc 25b7ae1 5f2fd70 636399f 8cdfbf3 636399f b1b5846 636399f 8cdfbf3 d7e7353 8cdfbf3 636399f 8cdfbf3 636399f 8cdfbf3 636399f 8cdfbf3 636399f 8cdfbf3 636399f 8cdfbf3 636399f 5f2fd70 636399f 8cdfbf3 d1b3892 8cdfbf3 685129e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 |
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModelForSeq2SeqLM
import logging
import re
app = FastAPI()
# Enable CORS if needed
from fastapi.middleware.cors import CORSMiddleware
app.add_middleware(
CORSMiddleware,
allow_origins=["*"], # In production, restrict this to your frontend URL
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)
####################################
# Text Generation Endpoint
####################################
TEXT_MODEL_NAME = "aubmindlab/aragpt2-medium"
text_tokenizer = AutoTokenizer.from_pretrained(TEXT_MODEL_NAME)
text_model = AutoModelForCausalLM.from_pretrained(TEXT_MODEL_NAME)
general_prompt_template = """
أنت الآن نموذج لغة مخصص لتوليد نصوص عربية تعليمية بناءً على المادة والمستوى التعليمي. سيتم إعطاؤك مادة تعليمية ومستوى تعليمي، وعليك إنشاء نص مناسب بناءً على ذلك. النص يجب أن يكون:
1. واضحًا وسهل الفهم.
2. مناسبًا للمستوى التعليمي المحدد.
3. مرتبطًا بالمادة التعليمية المطلوبة.
4. قصيرًا ومباشرًا.
### أمثلة:
1. المادة: العلوم
المستوى: الابتدائي
النص: النباتات كائنات حية تحتاج إلى الماء والهواء وضوء الشمس لتنمو. بعض النباتات تنتج أزهارًا وفواكه. النباتات تساعدنا في الحصول على الأكسجين.
2. المادة: التاريخ
المستوى: المتوسط
النص: التاريخ هو دراسة الماضي وأحداثه المهمة. من خلال التاريخ، نتعلم عن الحضارات القديمة مثل الحضارة الفرعونية والحضارة الإسلامية. التاريخ يساعدنا على فهم تطور البشرية.
3. المادة: الجغرافيا
المستوى: المتوسط
النص: الجغرافيا هي دراسة الأرض وخصائصها. نتعلم عن القارات والمحيطات والجبال. الجغرافيا تساعدنا على فهم العالم من حولنا.
---
المادة: {المادة}
المستوى: {المستوى}
اكتب نصًا مناسبًا بناءً على المادة والمستوى المحددين. ركّز على جعل النص بسيطًا وواضحًا للمستوى المطلوب.
"""
class GenerateTextRequest(BaseModel):
المادة: str
المستوى: str
@app.post("/generate-text")
def generate_text(request: GenerateTextRequest):
if not request.المادة or not request.المستوى:
raise HTTPException(status_code=400, detail="المادة والمستوى مطلوبان.")
try:
prompt = general_prompt_template.format(المادة=request.المادة, المستوى=request.المستوى)
inputs = text_tokenizer(prompt, return_tensors="pt", max_length=1024, truncation=True)
with torch.no_grad():
outputs = text_model.generate(
inputs.input_ids,
max_length=300,
num_return_sequences=1,
temperature=0.7,
top_p=0.95,
do_sample=True,
)
generated_text = text_tokenizer.decode(outputs[0], skip_special_tokens=True).replace(prompt, "").strip()
logger.info(f"Generated text: {generated_text}")
return {"generated_text": generated_text}
except Exception as e:
logger.error(f"Error during text generation: {str(e)}")
raise HTTPException(status_code=500, detail=f"Error during text generation: {str(e)}")
####################################
# Question & Answer Generation Model
####################################
QA_MODEL_NAME = "Mihakram/AraT5-base-question-generation"
qa_tokenizer = AutoTokenizer.from_pretrained(QA_MODEL_NAME)
qa_model = AutoModelForSeq2SeqLM.from_pretrained(QA_MODEL_NAME)
def extract_answer(context: str) -> str:
"""Extract the first sentence (or a key phrase) from the context."""
sentences = re.split(r'[.!؟]', context)
sentences = [s.strip() for s in sentences if s.strip()]
return sentences[0] if sentences else context
def get_question(context: str, answer: str) -> str:
"""Generate a question based on the context and the candidate answer."""
text = f"النص: {context} الإجابة: {answer} </s>"
text_encoding = qa_tokenizer.encode_plus(text, return_tensors="pt")
with torch.no_grad():
generated_ids = qa_model.generate(
input_ids=text_encoding['input_ids'],
attention_mask=text_encoding['attention_mask'],
max_length=64,
num_beams=5,
num_return_sequences=1
)
question = qa_tokenizer.decode(generated_ids[0], skip_special_tokens=True).replace("question:", "").strip()
return question
class GenerateQARequest(BaseModel):
text: str
@app.post("/generate-qa")
def generate_qa(request: GenerateQARequest):
if not request.text:
raise HTTPException(status_code=400, detail="Text is required.")
try:
question, answer = get_question(request.text, extract_answer(request.text))
logger.info(f"Generated QA -> Question: {question}, Answer: {answer}")
return {"question": question, "answer": answer}
except Exception as e:
logger.error(f"Error during QA generation: {str(e)}")
raise HTTPException(status_code=500, detail=f"Error during QA generation: {str(e)}")
####################################
# Root Endpoint
####################################
@app.get("/")
def read_root():
return {"message": "Welcome to the Arabic Text Generation API!"}
###################################
# Running the FastAPI Server
####################################
import uvicorn
uvicorn.run(app, host="0.0.0.0", port=7860)
|