File size: 6,814 Bytes
7c40d04
0133631
2dd1f0a
e1933c4
1ece3c6
7c40d04
 
b0c5829
e1933c4
 
e4872e8
b0c5829
6991b14
7c40d04
 
 
8def51d
2dd1f0a
b0c5829
7c40d04
2dd1f0a
 
b273b4c
7c40d04
b0c5829
e1933c4
 
2dd1f0a
 
 
 
e1933c4
6991b14
0133631
8def51d
0133631
57d09d7
 
0133631
7c40d04
 
0133631
7c40d04
2dd1f0a
 
 
 
 
 
b273b4c
2dd1f0a
b273b4c
2dd1f0a
b273b4c
2dd1f0a
b273b4c
 
 
2dd1f0a
 
 
 
 
 
7c40d04
 
0133631
7c40d04
e1933c4
 
7c40d04
6991b14
e1933c4
 
7c40d04
6991b14
e1933c4
8def51d
2dd1f0a
0133631
e1933c4
0133631
e1933c4
0133631
e1933c4
7c40d04
 
 
 
 
b0c5829
7c40d04
 
b0c5829
0133631
7c40d04
 
b0c5829
 
 
7c40d04
 
8def51d
 
 
 
 
7c40d04
 
2dd1f0a
 
 
b0c5829
2dd1f0a
7c40d04
8def51d
b273b4c
b0c5829
8def51d
2dd1f0a
0133631
 
 
7c40d04
2dd1f0a
 
 
7c40d04
 
b0c5829
7c40d04
 
 
2dd1f0a
 
 
b0c5829
2dd1f0a
7c40d04
2dd1f0a
 
b0c5829
7c40d04
0133631
2dd1f0a
 
8def51d
2dd1f0a
 
 
7c40d04
b273b4c
2dd1f0a
7c40d04
 
2dd1f0a
7c40d04
 
 
2dd1f0a
 
 
b0c5829
2dd1f0a
7c40d04
 
2dd1f0a
 
 
cbdf3eb
7c40d04
 
 
6991b14
e1933c4
7c40d04
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
import os, io
from pathlib import Path

from fastapi import FastAPI, UploadFile, File, Form
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, HTMLResponse, FileResponse
from fastapi.staticfiles import StaticFiles
from huggingface_hub import InferenceClient
from PyPDF2 import PdfReader
from docx import Document
from PIL import Image
from io import BytesIO

# -----------------------------------------------------------------------------
# CONFIGURATION
# -----------------------------------------------------------------------------
# Hugging Face API token taken from the environment; None when HF_TOKEN is unset.
HUGGINGFACE_TOKEN = os.environ.get("HF_TOKEN")
# Port used by the __main__ entrypoint; falls back to 7860 when PORT is unset.
PORT = int(os.environ.get("PORT", 7860))

app = FastAPI(
    title="AI-Powered Web-App API",
    description="Backend for summarisation, captioning & QA",
    version="1.2.3",
)

# CORS is fully open: any origin, method and header is accepted.
# NOTE(review): FastAPI's CORS documentation advises listing origins
# explicitly when allow_credentials is True -- confirm the wildcard is intended.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# -----------------------------------------------------------------------------
# OPTIONAL STATIC FILES
# -----------------------------------------------------------------------------
# Mount ./static under /static only when it is an actual directory.
# StaticFiles validates its target when constructed, so a stray regular file
# named "static" (which Path.exists() would accept) must not trigger the mount.
static_dir = Path("static")
if static_dir.is_dir():
    app.mount("/static", StaticFiles(directory=str(static_dir)), name="static")

# -----------------------------------------------------------------------------
# HUGGING FACE INFERENCE CLIENTS
# -----------------------------------------------------------------------------
# Client used by the summarisation route.
summary_client = InferenceClient(
    "facebook/bart-large-cnn",
    token=HUGGINGFACE_TOKEN,
    timeout=120,
)

# Client used by the question-answering route.  Per the original author's
# note, this model was picked over the roberta-base variant for accuracy, and
# "deepset/xlm-roberta-large-squad2" can be swapped in (same token/timeout
# arguments) if multilingual support is needed.
qa_client = InferenceClient(
    "deepset/roberta-large-squad2",
    token=HUGGINGFACE_TOKEN,
    timeout=120,
)

# Client used for image captioning; configured with a shorter timeout (60)
# than the two text clients (120).
image_caption_client = InferenceClient(
    "nlpconnect/vit-gpt2-image-captioning",
    token=HUGGINGFACE_TOKEN,
    timeout=60,
)

# -----------------------------------------------------------------------------
# UTILITIES
# -----------------------------------------------------------------------------
def extract_text_from_pdf(content: bytes) -> str:
    """Return the text of every page of a PDF, joined with newlines.

    A page whose extraction yields nothing contributes an empty string, and
    the combined text has leading/trailing whitespace stripped.
    """
    pdf = PdfReader(io.BytesIO(content))
    page_texts = []
    for page in pdf.pages:
        extracted = page.extract_text()
        page_texts.append(extracted if extracted else "")
    return "\n".join(page_texts).strip()

def extract_text_from_docx(content: bytes) -> str:
    """Return the paragraph text of a DOCX file, one paragraph per line.

    Only ``doc.paragraphs`` is read; the joined text has leading/trailing
    whitespace stripped.
    """
    stream = io.BytesIO(content)
    paragraph_texts = [para.text for para in Document(stream).paragraphs]
    return "\n".join(paragraph_texts).strip()

def process_uploaded_file(file: UploadFile) -> str:
    """Extract plain text from an uploaded PDF, DOCX or TXT file.

    The handler is chosen from the filename extension (case-insensitive).

    Args:
        file: Upload whose ``filename`` and underlying ``file`` stream are read.

    Returns:
        The extracted text with surrounding whitespace stripped.

    Raises:
        ValueError: If the filename is missing, has no extension, or has an
            extension other than pdf/docx/txt.  A TXT payload that is not
            valid UTF-8 raises ``UnicodeDecodeError`` (a ``ValueError``
            subclass).
    """
    # rpartition tells "report.pdf" apart from a bare "pdf": with no dot there
    # is no extension, so the whole name must not be mistaken for one.
    _, dot, ext = (file.filename or "").rpartition(".")
    ext = ext.lower() if dot else ""
    if ext not in ("pdf", "docx", "txt"):
        raise ValueError("Unsupported file type")

    content = file.file.read()
    if ext == "pdf":
        return extract_text_from_pdf(content)
    if ext == "docx":
        return extract_text_from_docx(content)
    # utf-8-sig decodes plain UTF-8 unchanged but drops a leading BOM, which
    # str.strip() would otherwise leave at the start of the text.
    return content.decode("utf-8-sig").strip()

# -----------------------------------------------------------------------------
# ROUTES
# -----------------------------------------------------------------------------
@app.get("/", response_class=HTMLResponse)
async def serve_index():
    """Serve the front-end page ``index.html`` (given as a relative path) at the site root."""
    index_page = FileResponse("index.html")
    return index_page

# -------------------- Summarisation ------------------------------------------
@app.post("/api/summarize")
async def summarize_document(file: UploadFile = File(...)):
    """Summarise an uploaded PDF, DOCX or TXT document.

    Args:
        file: The uploaded document (multipart field ``file``).

    Returns:
        ``{"result": <text>}`` on success, where the text is the summary or a
        notice that the document is shorter than 20 characters.
        A JSON ``{"error": ...}`` response with status 400 when the upload
        cannot be processed (unsupported type, undecodable text), or status
        500 for any other failure.
    """
    # Local import keeps this handler self-contained.
    from fastapi.concurrency import run_in_threadpool

    try:
        try:
            # File parsing is blocking work; keep it off the event loop.
            text = await run_in_threadpool(process_uploaded_file, file)
        except ValueError as exc:
            # Bad input is the client's fault, not a server error.
            return JSONResponse(status_code=400,
                                content={"error": f"Summarisation failure: {exc}"})

        if len(text) < 20:
            return {"result": "Document too short to summarise."}

        # The inference call is a blocking HTTP request (timeout up to 120 s);
        # running it in a worker thread keeps other requests responsive.
        summary_raw = await run_in_threadpool(
            summary_client.summarization, text[:3000]
        )
        # The response shape varies: list of dicts, a single dict, or other.
        summary_txt = (
            summary_raw[0].get("summary_text") if isinstance(summary_raw, list) else
            summary_raw.get("summary_text")   if isinstance(summary_raw, dict) else
            str(summary_raw)
        )
        return {"result": summary_txt}
    except Exception as exc:
        return JSONResponse(status_code=500,
                            content={"error": f"Summarisation failure: {exc}"})


# -------------------- Image Caption ------------------------------------------
@app.post("/api/caption")
async def caption_image(image: UploadFile = File(...)):
    """Generate a caption for an uploaded image.

    The multipart field is named ``image`` to match the frontend (it was
    ``file`` before).

    Args:
        image: The uploaded image.

    Returns:
        ``{"result": <caption>}`` on success.  A JSON ``{"error": ...}``
        response with status 400 when the upload cannot be decoded as an
        image, or status 500 for any other failure.
    """
    # Local import keeps this handler self-contained.
    from fastapi.concurrency import run_in_threadpool

    def _to_jpeg(raw: bytes) -> bytes:
        """Decode *raw*, shrink it to fit 1024x1024, and re-encode as JPEG."""
        img = Image.open(io.BytesIO(raw)).convert("RGB")
        img.thumbnail((1024, 1024))
        buf = BytesIO()
        img.save(buf, format="JPEG")
        return buf.getvalue()

    try:
        img_bytes = await image.read()
        try:
            # Decoding/encoding is CPU-bound; keep it off the event loop.
            jpeg_bytes = await run_in_threadpool(_to_jpeg, img_bytes)
        except OSError as exc:
            # Pillow reports unreadable/truncated images as OSError
            # (UnidentifiedImageError is a subclass): a client error.
            return JSONResponse(status_code=400,
                                content={"error": f"Caption failure: {exc}"})

        # Blocking HTTP request; run it in a worker thread.
        result = await run_in_threadpool(
            image_caption_client.image_to_text, jpeg_bytes
        )

        # Normalise a list response to its first element so list and dict
        # responses share the same fallbacks; an empty list means no caption.
        if isinstance(result, list):
            result = result[0] if result else {}

        if isinstance(result, dict):
            caption = (result.get("generated_text")
                       or result.get("caption")
                       or "No caption found.")
        else:
            caption = str(result)
        return {"result": caption}
    except Exception as exc:
        return JSONResponse(status_code=500,
                            content={"error": f"Caption failure: {exc}"})


# -------------------- Question Answering -------------------------------------
@app.post("/api/qa")
async def question_answering(file: UploadFile = File(...),
                             question: str = Form(...)):
    """Answer a question about an uploaded document or image.

    For uploads whose content type starts with ``image/`` the context is a
    generated caption of the image; otherwise it is the first 3000 characters
    of the text extracted from the document.

    Args:
        file: The uploaded document or image.
        question: The question to answer against that context.

    Returns:
        ``{"result": <answer>}`` on success (or a notice when no context could
        be produced).  A JSON ``{"error": ...}`` response with status 400 when
        the upload cannot be processed, or status 500 for any other failure.
    """
    # Local import keeps this handler self-contained.
    from fastapi.concurrency import run_in_threadpool

    def _to_jpeg(raw: bytes) -> bytes:
        """Decode *raw*, shrink it to fit 1024x1024, and re-encode as JPEG."""
        img = Image.open(io.BytesIO(raw)).convert("RGB")
        img.thumbnail((1024, 1024))
        buf = BytesIO()
        img.save(buf, format="JPEG")
        return buf.getvalue()

    try:
        # content_type can be absent on an upload; treat that as "not an image".
        if (file.content_type or "").startswith("image/"):
            img_bytes = await file.read()
            try:
                jpeg_bytes = await run_in_threadpool(_to_jpeg, img_bytes)
            except OSError as exc:
                # Pillow reports unreadable images as OSError: a client error.
                return JSONResponse(status_code=400,
                                    content={"error": f"QA failure: {exc}"})
            # Blocking HTTP request; run it in a worker thread.
            res = await run_in_threadpool(
                image_caption_client.image_to_text, jpeg_bytes
            )
            # Accept a list response the same way the caption route does,
            # instead of stringifying the whole list into the context.
            if isinstance(res, list):
                res = res[0] if res else {}
            context = (res.get("generated_text") if isinstance(res, dict)
                       else str(res))
        else:
            try:
                text = await run_in_threadpool(process_uploaded_file, file)
            except ValueError as exc:
                # Unsupported type / undecodable text: a client error.
                return JSONResponse(status_code=400,
                                    content={"error": f"QA failure: {exc}"})
            context = text[:3000]

        if not context:
            return {"result": "No context – cannot answer."}

        # Blocking HTTP request (timeout up to 120 s); keep the loop free.
        answer = await run_in_threadpool(
            qa_client.question_answering, question=question, context=context
        )
        # Be tolerant of a list-shaped response: use its first element.
        if isinstance(answer, list):
            answer = answer[0] if answer else {}
        return {"result": answer.get("answer", "No answer found.")}
    except Exception as exc:
        return JSONResponse(status_code=500,
                            content={"error": f"QA failure: {exc}"})


# -------------------- Health --------------------------------------------------
@app.get("/api/health")
async def health():
    """Report service status, whether an HF token is configured, and the API version."""
    token_configured = True if HUGGINGFACE_TOKEN else False
    return dict(
        status="healthy",
        hf_token_set=token_configured,
        version=app.version,
    )

# -----------------------------------------------------------------------------
# ENTRYPOINT
# -----------------------------------------------------------------------------
if __name__ == "__main__":
    # Imported here so uvicorn is only required when run as a script.
    from uvicorn import run as run_server

    # Listen on all interfaces on the configured port.
    run_server(app, host="0.0.0.0", port=PORT)