benkada committed
Commit e4872e8 · verified · 1 Parent(s): cb57f04

Update main.py

Files changed (1)
  1. main.py +128 -213
main.py CHANGED
@@ -1,213 +1,128 @@
- from fastapi import FastAPI, UploadFile, File, Form, HTTPException
- from fastapi.responses import JSONResponse
- from fastapi.middleware.cors import CORSMiddleware
- from pydantic import BaseModel
- from typing import Optional
- import os
- import tempfile
- from transformers import pipeline
- import torch
- from PIL import Image
- import pytesseract
- from langchain.chains import LLMChain
- from langchain.prompts import PromptTemplate
- from langchain_community.llms import HuggingFaceHub
-
- # Initialize FastAPI app
- app = FastAPI(
-     title="AI-Powered Web Application API",
-     description="API for document analysis, image captioning, and question answering",
-     version="1.0.0"
- )
-
- # CORS configuration
- app.add_middleware(
-     CORSMiddleware,
-     allow_origins=["*"],
-     allow_credentials=True,
-     allow_methods=["*"],
-     allow_headers=["*"],
- )
-
- # Initialize AI models (lazy loading)
- summarizer = None
- image_captioner = None
- qa_chain = None
-
- class SummaryRequest(BaseModel):
-     file: UploadFile = File(...)
-
- class CaptionRequest(BaseModel):
-     file: UploadFile = File(...)
-
- class QARequest(BaseModel):
-     file: UploadFile = File(...)
-     question: str = Form(...)
-
- def initialize_models():
-     """Initialize AI models with optimized prompts"""
-     global summarizer, image_captioner, qa_chain
-
-     # Document summarization model
-     if summarizer is None:
-         summarizer = pipeline(
-             "summarization",
-             model="facebook/bart-large-cnn",
-             device=0 if torch.cuda.is_available() else -1
-         )
-
-     # Image captioning model
-     if image_captioner is None:
-         image_captioner = pipeline(
-             "image-to-text",
-             model="nlpconnect/vit-gpt2-image-captioning",
-             device=0 if torch.cuda.is_available() else -1
-         )
-
-     # Question answering chain
-     if qa_chain is None:
-         llm = HuggingFaceHub(
-             repo_id="google/flan-t5-large",
-             model_kwargs={"temperature": 0.1, "max_length": 512}
-         )
-
-         qa_prompt = PromptTemplate(
-             input_variables=["document", "question"],
-             template="""
- Using the provided document, answer the following question precisely.
- If the answer cannot be determined from the document, respond with
- 'The answer cannot be determined from the provided document.'
-
- Question: {question}
-
- Rules:
- 1. Provide a concise answer (1-3 sentences maximum)
- 2. When possible, reference the specific section of the document that supports your answer
- 3. Maintain numerical precision when answering quantitative questions
- 4. For comparison questions, highlight both items being compared
-
- Document: {document}
- """
-         )
-         qa_chain = LLMChain(llm=llm, prompt=qa_prompt)
-
- def extract_text_from_file(file: UploadFile) -> str:
-     """Extract text from various file formats"""
-     # Create a temporary file
-     with tempfile.NamedTemporaryFile(delete=False) as temp_file:
-         temp_file.write(file.file.read())
-         temp_path = temp_file.name
-
-     try:
-         # PDF, DOCX, PPTX, XLSX would need appropriate libraries here
-         # For simplicity, we'll just read text files in this example
-         if file.filename.endswith('.txt'):
-             with open(temp_path, 'r', encoding='utf-8') as f:
-                 return f.read()
-         else:
-             # In a real implementation, use libraries like PyPDF2, python-docx, etc.
-             raise HTTPException(
-                 status_code=415,
-                 detail="File type not supported in this example implementation"
-             )
-     finally:
-         os.unlink(temp_path)
-
- @app.post("/api/summarize")
- async def summarize_document(file: UploadFile = File(...)):
-     """Summarize a document"""
-     initialize_models()
-
-     try:
-         # Extract text from the document
-         document_text = extract_text_from_file(file)
-
-         # Generate summary with optimized prompt
-         summary = summarizer(
-             document_text,
-             max_length=150,
-             min_length=30,
-             do_sample=False,
-             truncation=True
-         )
-
-         return JSONResponse(
-             content={"status": "success", "result": summary[0]['summary_text']},
-             status_code=200
-         )
-     except Exception as e:
-         raise HTTPException(
-             status_code=500,
-             detail=f"Error processing document: {str(e)}"
-         )
-
- @app.post("/api/caption")
- async def generate_image_caption(file: UploadFile = File(...)):
-     """Generate caption for an image"""
-     initialize_models()
-
-     try:
-         # Save the uploaded image temporarily
-         with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as temp_file:
-             temp_file.write(file.file.read())
-             temp_path = temp_file.name
-
-         # Open the image
-         image = Image.open(temp_path)
-
-         # Generate caption with optimized prompt
-         caption = image_captioner(
-             image,
-             generate_kwargs={
-                 "max_length": 50,
-                 "num_beams": 4,
-                 "early_stopping": True
-             }
-         )
-
-         return JSONResponse(
-             content={"status": "success", "result": caption[0]['generated_text']},
-             status_code=200
-         )
-     except Exception as e:
-         raise HTTPException(
-             status_code=500,
-             detail=f"Error processing image: {str(e)}"
-         )
-     finally:
-         if 'temp_path' in locals() and os.path.exists(temp_path):
-             os.unlink(temp_path)
-
- @app.post("/api/qa")
- async def answer_question(
-     file: UploadFile = File(...),
-     question: str = Form(...)
- ):
-     """Answer questions based on document content"""
-     initialize_models()
-
-     try:
-         # Extract text from the document
-         document_text = extract_text_from_file(file)
-
-         # Get answer using the QA chain
-         answer = qa_chain.run(document=document_text, question=question)
-
-         return JSONResponse(
-             content={"status": "success", "result": answer},
-             status_code=200
-         )
-     except Exception as e:
-         raise HTTPException(
-             status_code=500,
-             detail=f"Error processing question: {str(e)}"
-         )
-
- @app.get("/")
- async def health_check():
-     """Health check endpoint"""
-     return {"status": "healthy", "version": "1.0.0"}
-
- if __name__ == "__main__":
-     import uvicorn
-     uvicorn.run(app, host="0.0.0.0", port=8000)
 
+ import os
+
+ # Ensure the HF cache directory is set before transformers is imported
+ os.environ.setdefault("HF_HOME", "/app/cache")
+
+ from fastapi import FastAPI, UploadFile, File, HTTPException
+ from fastapi.middleware.cors import CORSMiddleware
+ from fastapi.responses import JSONResponse
+ from pydantic import BaseModel
+ from typing import Optional
+ from PIL import Image
+ from transformers import pipeline
+ from langchain.chains import LLMChain
+ from langchain.prompts import PromptTemplate
+ from langchain_community.llms import HuggingFaceHub
+
+ # FastAPI application
+ app = FastAPI(
+     title="AI-Powered Web Application API",
+     description="API for document summarization, image captioning, and question answering",
+     version="1.0.0"
+ )
+
+ # CORS middleware
+ app.add_middleware(
+     CORSMiddleware,
+     allow_origins=["*"],
+     allow_credentials=True,
+     allow_methods=["*"],
+     allow_headers=["*"],
+ )
+
+ # ----------------
+ # Schemas
+ # ----------------
+ class SummarizeRequest(BaseModel):
+     text: str
+     max_length: Optional[int] = 150
+     min_length: Optional[int] = 40
+
+ class QARequest(BaseModel):
+     question: str
+     context: Optional[str] = None
+
+ # ----------------
+ # Model loaders (lazy)
+ # ----------------
+ _cache_dir = os.getenv("HF_HOME", "/app/cache")
+ _summarizer = None
+ _captioner = None
+ _qa_chain = None
+
+
+ def get_summarizer():
+     global _summarizer
+     if _summarizer is None:
+         _summarizer = pipeline(
+             "summarization",
+             model="facebook/bart-large-cnn",
+             model_kwargs={"cache_dir": _cache_dir}
+         )
+     return _summarizer
+
+
+ def get_image_captioner():
+     global _captioner
+     if _captioner is None:
+         _captioner = pipeline(
+             "image-to-text",
+             model="nlpconnect/vit-gpt2-image-captioning",
+             model_kwargs={"cache_dir": _cache_dir}
+         )
+     return _captioner
+
+
+ def get_qa_chain():
+     global _qa_chain
+     if _qa_chain is None:
+         # HuggingFaceHub calls the hosted Inference API; model_kwargs are inference parameters, not local cache settings
+         llm = HuggingFaceHub(
+             repo_id="google/flan-t5-large",
+             model_kwargs={"temperature": 0.1, "max_length": 512},
+             huggingfacehub_api_token=os.getenv("HUGGINGFACEHUB_API_TOKEN")
+         )
+         prompt = PromptTemplate(
+             input_variables=["context", "question"],
+             template="""
+ Use the following context to answer the question:
+
+ {context}
+
+ Question: {question}
+ Answer:"""
+         )
+         _qa_chain = LLMChain(llm=llm, prompt=prompt)
+     return _qa_chain
+
+ # ----------------
+ # Routes
+ # ----------------
+ @app.post("/summarize")
+ def summarize(req: SummarizeRequest):
+     summarizer = get_summarizer()
+     result = summarizer(
+         req.text,
+         max_length=req.max_length,
+         min_length=req.min_length,
+         truncation=True,
+         clean_up_tokenization_spaces=True
+     )
+     return JSONResponse(content={"summary": result[0]["summary_text"]})
+
+ @app.post("/caption")
+ async def caption_image(file: UploadFile = File(...)):
+     try:
+         img = Image.open(file.file).convert("RGB")
+         captioner = get_image_captioner()
+         result = captioner(img)
+         return JSONResponse(content={"caption": result[0]["generated_text"]})
+     except Exception as e:
+         raise HTTPException(status_code=400, detail=str(e))
+
+ @app.post("/qa")
+ def question_answer(req: QARequest):
+     chain = get_qa_chain()
+     context = req.context or ""
+     answer = chain.run({"context": context, "question": req.question})
+     return JSONResponse(content={"answer": answer})
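
For reference, a minimal client sketch against the new routes; this is hypothetical usage, not part of the commit. It assumes the server is started separately (the uvicorn `__main__` block is gone in this revision, so e.g. `uvicorn main:app --port 8000`), that `HUGGINGFACEHUB_API_TOKEN` is set in the server environment for the `/qa` route, and that the `requests` package is installed.

# Client-side sketch (hypothetical usage; not part of the commit)
import requests

BASE = "http://localhost:8000"  # assumption: local deployment

# /summarize expects JSON matching SummarizeRequest
resp = requests.post(f"{BASE}/summarize", json={
    "text": "Long document text goes here ...",
    "max_length": 120,
    "min_length": 30,
})
print(resp.json()["summary"])

# /caption expects a multipart file upload
with open("photo.jpg", "rb") as f:
    resp = requests.post(f"{BASE}/caption", files={"file": ("photo.jpg", f, "image/jpeg")})
print(resp.json()["caption"])

# /qa expects JSON matching QARequest; context is optional
resp = requests.post(f"{BASE}/qa", json={
    "question": "What is the main topic?",
    "context": "Text for the model to ground its answer in ...",
})
print(resp.json()["answer"])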