# qtAnswering/main.py -- last change: "Update qtAnswering/main.py" (commit 2d75ddd, verified), uploaded by ikraamkb.
import os
import shutil
from tempfile import NamedTemporaryFile, gettempdir

from fastapi import FastAPI, File, Form, Request, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse, HTMLResponse, JSONResponse
from fastapi.templating import Jinja2Templates
# ✅ Create app
app = FastAPI()

# ✅ CORS middleware: any origin, method and header is accepted.
# NOTE(review): wildcard origins together with allow_credentials=True is very
# permissive -- confirm this is intended before exposing the app publicly.
_cors_options = {
    "allow_origins": ["*"],
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_cors_options)

# ✅ Templates: HTML pages are loaded from the "templates" directory.
templates = Jinja2Templates(directory="templates")
# ✅ Serve homepage
@app.get("/", response_class=HTMLResponse)
async def serve_home(request: Request):
    """Render ``home.html``, passing the incoming request as template context."""
    context = {"request": request}
    return templates.TemplateResponse("home.html", context)
class _NamedFile:
    """Minimal file-like wrapper exposing ``filename`` and ``read()`` for a path."""

    def __init__(self, name):
        self.filename = name

    def read(self):
        # Context manager so the handle is closed even if the read fails.
        with open(self.filename, "rb") as handle:
            return handle.read()


def _upload_suffix(filename):
    """Return a filesystem-safe extension (e.g. ``".pdf"``) from a client-supplied name."""
    extension = os.path.splitext(os.path.basename(filename or ""))[1]
    return "".join(ch for ch in extension if ch.isalnum() or ch == ".")


# ✅ Predict endpoint (handles image + document)
@app.post("/predict")
async def predict(question: str = Form(...), file: UploadFile = File(...)):
    """Answer ``question`` about an uploaded image or document.

    The upload is copied to a uniquely named temporary file, then routed to
    the image handler when its declared content type starts with ``image/``
    and to the document handler otherwise. The temporary file is removed
    whether or not the handler succeeds.

    Args:
        question: The question text, sent as a form field.
        file: The uploaded image or document, sent as a multipart file field.

    Returns:
        A ``JSONResponse`` with ``"answer"`` and, when the handler returned an
        audio path that exists on disk, an ``"audio"`` URL. On any failure the
        body is ``{"error": <message>}`` with status 500.
    """
    temp_path = None
    try:
        # Never build the path from the client-supplied name: it may contain
        # path separators, and two uploads with the same name would collide.
        # Only the (sanitised) extension is kept.
        # NOTE(review): the document handler receives this path as
        # ``filename``; the extension is preserved in case it dispatches on
        # it -- confirm against answer_question_from_doc.
        with NamedTemporaryFile(
            "wb",
            prefix="temp_",
            suffix=_upload_suffix(file.filename),
            delete=False,
        ) as tmp:
            temp_path = tmp.name
            shutil.copyfileobj(file.file, tmp)

        # content_type is optional on an upload; treat a missing one as
        # "not an image" instead of raising AttributeError.
        is_image = (file.content_type or "").startswith("image/")

        # NOTE(review): both handlers are called synchronously inside this
        # coroutine; if they are slow they will block the event loop.
        if is_image:
            from .appImage import answer_question_from_image
            from PIL import Image

            # convert() returns an independent image, so the file handle
            # opened by Image.open can be released immediately.
            with Image.open(temp_path) as raw_image:
                image = raw_image.convert("RGB")
            answer, audio_path = answer_question_from_image(image, question)
        else:
            from .app import answer_question_from_doc

            answer, audio_path = answer_question_from_doc(
                _NamedFile(temp_path), question
            )

        if audio_path and os.path.exists(audio_path):
            # NOTE(review): the "/qtAnswering" prefix presumably matches where
            # this app is mounted, and the audio file is presumably written to
            # the system temp directory that the audio route serves -- confirm.
            return JSONResponse({
                "answer": answer,
                "audio": f"/qtAnswering/audio/{os.path.basename(audio_path)}",
            })
        return JSONResponse({"answer": answer})
    except Exception as e:
        # Endpoint boundary: report every failure to the client as a JSON 500.
        return JSONResponse({"error": str(e)}, status_code=500)
    finally:
        if temp_path is not None:
            try:
                os.remove(temp_path)
            except OSError:
                # Best-effort cleanup; a failed delete must not replace the
                # response that has already been built.
                pass
# ✅ Serve audio files
@app.get("/audio/{filename}")
async def get_audio(filename: str):
    """Serve an audio file from the system temporary directory.

    Args:
        filename: Bare file name (no directory components) of the audio file.

    Returns:
        A ``FileResponse`` with media type ``audio/mpeg`` when the file
        exists; otherwise a JSON ``{"error": ...}`` body with status 404.
    """
    # Only bare names are accepted, so the request cannot address anything
    # outside the temp directory.
    if os.path.basename(filename) != filename:
        return JSONResponse({"error": "Audio file not found"}, status_code=404)

    filepath = os.path.join(gettempdir(), filename)
    # Check up front: this yields a clean 404 for missing files (and for
    # "." / "..", which are directories) instead of a failure inside
    # FileResponse.
    if not os.path.isfile(filepath):
        return JSONResponse({"error": "Audio file not found"}, status_code=404)

    return FileResponse(filepath, media_type="audio/mpeg")