Ankit Shrestha committed on
Commit 2e94917 · 1 Parent(s): b0c7f29

Refactor and remove old endpoints

Files changed (2):
  1. main.py  +181 -185
  2. requirements.txt  +0 -1
main.py CHANGED

@@ -1,3 +1,77 @@
+import time
+from io import BytesIO
+import os
+from dotenv import load_dotenv
+from PIL import Image
+import logging
+from typing import List
+from huggingface_hub import login
+from fastapi import FastAPI, File, UploadFile
+from vllm import LLM, SamplingParams
+import torch
+import torch._dynamo
+torch._dynamo.config.suppress_errors = True
+
+# Configure logging
+logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
+logger = logging.getLogger(__name__)
+
+# Load environment variables
+load_dotenv()
+
+# Set the cache directory to a writable path
+os.environ["TORCHINDUCTOR_CACHE_DIR"] = "/tmp/torch_inductor_cache"
+token = os.getenv("huggingface_ankit")
+
+# Login to the Hugging Face Hub
+login(token)
+
+app = FastAPI()
+
+llm = None
+
+def load_vllm_model():
+    global llm
+    logger.info("Loading vLLM model...")
+    if llm is None:
+        llm = LLM(
+            model="google/paligemma2-3b-mix-448",
+            trust_remote_code=True,
+            max_model_len=4096,
+            dtype="float16",
+        )
+
+@app.post("/batch_extract_text_vllm")
+async def batch_extract_text_vllm(files: List[UploadFile] = File(...)):
+    try:
+        start_time = time.time()
+        load_vllm_model()
+        results = []
+        sampling_params = SamplingParams(temperature=0.0, max_tokens=32)
+        # Load images
+        images = []
+        for file in files:
+            image_data = await file.read()
+            img = Image.open(BytesIO(image_data)).convert("RGB")
+            images.append(img)
+        for image in images:
+            inputs = {
+                "prompt": "ocr",
+                "multi_modal_data": {
+                    "image": image
+                },
+            }
+            outputs = llm.generate(inputs, sampling_params)
+            for o in outputs:
+                generated_text = o.outputs[0].text
+                results.append(generated_text)
+
+        logger.info(f"vLLM Batch processing completed in {time.time() - start_time:.2f} seconds")
+        return {"extracted_texts": results}
+    except Exception as e:
+        logger.error(f"Error in batch processing vLLM: {str(e)}")
+        return {"error": str(e)}
+
 # # main.py
 # from fastapi import FastAPI, File, UploadFile
 # from transformers import PaliGemmaProcessor, PaliGemmaForConditionalGeneration
@@ -63,219 +137,141 @@
 # if __name__ == "__main__":
 #     import uvicorn
 #     uvicorn.run(app, host="0.0.0.0", port=7860)
-
-from fastapi import FastAPI, File, UploadFile, BackgroundTasks
-from transformers import PaliGemmaProcessor, PaliGemmaForConditionalGeneration
-import torch
-from io import BytesIO
-import os
-from dotenv import load_dotenv
-from PIL import Image
-from huggingface_hub import login
-import gc
-import logging
-from typing import List
-import time
-import numpy as np
-from vllm import LLM, SamplingParams
-import torch._dynamo
-torch._dynamo.config.suppress_errors = True
-
-# Configure logging
-logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
-logger = logging.getLogger(__name__)
-
-# Load environment variables
-load_dotenv()
-
-# Set the cache directory to a writable path
-os.environ["TORCHINDUCTOR_CACHE_DIR"] = "/tmp/torch_inductor_cache"
-token = os.getenv("huggingface_ankit")
-
-# Login to the Hugging Face Hub
-login(token)
-
-app = FastAPI()
-
 # Global variables for model and processor
-model = None
-processor = None
-llm = None
-
-def load_model():
-    """Load model and processor when needed"""
-    global model, processor
-    if model is None:
-        model_id = "google/paligemma2-3b-mix-448"
-        logger.info(f"Loading model {model_id}")

-        # Load model with memory-efficient settings
-        model = PaliGemmaForConditionalGeneration.from_pretrained(
-            model_id,
-            device_map="auto",
-            torch_dtype=torch.bfloat16 # Use lower precision for memory efficiency
-        )
-        processor = PaliGemmaProcessor.from_pretrained(model_id)
-        logger.info("Model loaded successfully")
-
-def load_vllm_model():
-    global llm
-    if llm is None:
-        llm = LLM(
-            model="google/paligemma2-3b-mix-448",
-            trust_remote_code=True,
-            max_model_len=4096,
-            dtype="float16",
-        )
-def clean_memory():
-    """Force garbage collection and clear CUDA cache"""
-    gc.collect()
-    if torch.cuda.is_available():
-        torch.cuda.empty_cache()
-        # Clear GPU cache
-        torch.cuda.empty_cache()
-        logger.info(f"Memory allocated after clearing cache: {torch.cuda.memory_allocated()} bytes")
-    logger.info("Memory cleaned")
+# model = None
+# processor = None
+# def load_model():
+#     """Load model and processor when needed"""
+#     global model, processor
+#     if model is None:
+#         model_id = "google/paligemma2-3b-mix-448"
+#         logger.info(f"Loading model {model_id}")

+#         # Load model with memory-efficient settings
+#         model = PaliGemmaForConditionalGeneration.from_pretrained(
+#             model_id,
+#             device_map="auto",
+#             torch_dtype=torch.bfloat16 # Use lower precision for memory efficiency
+#         )
+#         processor = PaliGemmaProcessor.from_pretrained(model_id)
+#         logger.info("Model loaded successfully")
+# def clean_memory():
+#     """Force garbage collection and clear CUDA cache"""
+#     gc.collect()
+#     if torch.cuda.is_available():
+#         torch.cuda.empty_cache()
+#         # Clear GPU cache
+#         torch.cuda.empty_cache()
+#         logger.info(f"Memory allocated after clearing cache: {torch.cuda.memory_allocated()} bytes")
+#     logger.info("Memory cleaned")

-def predict(image):
-    """Process a single image"""
-    load_model()  # Ensure model is loaded

-    # Process input
-    prompt = "<image> ocr"
-    model_inputs = processor(text=prompt, images=image, return_tensors="pt")

-    # Move to appropriate device
-    model_inputs = {k: v.to(model.device) for k, v in model_inputs.items()}

-    # Generate with memory optimization
-    with torch.inference_mode():
-        generation = model.generate(**model_inputs, max_new_tokens=200)

-    # Decode output
-    decoded = processor.decode(generation[0], skip_special_tokens=True)

-    # Clean up intermediates
-    del model_inputs, generation
-    clean_memory()
-    # del model,processor
-    return decoded
+# def predict(image):
+#     """Process a single image"""
+#     load_model()  # Ensure model is loaded

+#     # Process input
+#     prompt = "<image> ocr"
+#     model_inputs = processor(text=prompt, images=image, return_tensors="pt")

+#     # Move to appropriate device
+#     model_inputs = {k: v.to(model.device) for k, v in model_inputs.items()}

+#     # Generate with memory optimization
+#     with torch.inference_mode():
+#         generation = model.generate(**model_inputs, max_new_tokens=200)

+#     # Decode output
+#     decoded = processor.decode(generation[0], skip_special_tokens=True)

+#     # Clean up intermediates
+#     del model_inputs, generation
+#     clean_memory()
+#     # del model,processor
+#     return decoded

-@app.post("/extract_text")
-async def extract_text(background_tasks: BackgroundTasks, file: UploadFile = File(...)):
-    """Extract text from a single image"""
-    try:
-        start_time = time.time()
-        image = Image.open(BytesIO(await file.read())).convert("RGB")
-        text = predict(image)

-        # Schedule cleanup after response
-        background_tasks.add_task(clean_memory)

-        logger.info(f"Processing completed in {time.time() - start_time:.2f} seconds")
-        return {"extracted_text": text}
-    except Exception as e:
-        logger.error(f"Error processing image: {str(e)}")
-        return {"error": str(e)}
+# @app.post("/extract_text")
+# async def extract_text(background_tasks: BackgroundTasks, file: UploadFile = File(...)):
+#     """Extract text from a single image"""
+#     try:
+#         start_time = time.time()
+#         image = Image.open(BytesIO(await file.read())).convert("RGB")
+#         text = predict(image)

+#         # Schedule cleanup after response
+#         background_tasks.add_task(clean_memory)

+#         logger.info(f"Processing completed in {time.time() - start_time:.2f} seconds")
+#         return {"extracted_text": text}
+#     except Exception as e:
+#         logger.error(f"Error processing image: {str(e)}")
+#         return {"error": str(e)}
-
-@app.post("/batch_extract_text_vllm")
-async def batch_extract_text_vllm(background_tasks: BackgroundTasks, files: List[UploadFile] = File(...)):
-    try:
-        start_time = time.time()
-        load_vllm_model()
-        results = []
-        sampling_params = SamplingParams(temperature=0.0,max_tokens=32)
-        # Load images
-        images = []
-        for file in files:
-            image_data = await file.read()
-            img = Image.open(BytesIO(image_data)).convert("RGB")
-            images.append(img)
-        for image in images:
-            inputs = {
-                "prompt": "ocr",
-                "multi_modal_data": {
-                    "image": image
-                },
-            }
-            outputs = llm.generate(inputs, sampling_params)
-            for o in outputs:
-                generated_text = o.outputs[0].text
-                results.append(" ocr\n"+generated_text)
-
-        logger.info(f"vLLM Batch processing completed in {time.time() - start_time:.2f} seconds")
-        return {"extracted_texts": results}
-    except Exception as e:
-        logger.error(f"Error in batch processing vLLM: {str(e)}")
-        return {"error": str(e)}
-
-@app.post("/batch_extract_text")
-async def batch_extract_text(batch_size:int, background_tasks: BackgroundTasks, files: List[UploadFile] = File(...)):
-    """Extract text from multiple images with batching"""
-    try:
-        start_time = time.time()

-        # Limit batch size for memory management
-        max_batch_size = 32 # Adjust based on your GPU memory

-        # if len(files) > 32:
-        #     return {"error": "A maximum of 20 images can be processed at a time."}

-        load_model()  # Ensure model is loaded

-        all_results = []

-        # Process in smaller batches
-        for i in range(0, len(files), max_batch_size):
-            batch_files = files[i:i+max_batch_size]

-            # Load images
-            images = []
-            for file in batch_files:
-                image_data = await file.read()
-                img = Image.open(BytesIO(image_data)).convert("RGB")
-                images.append(img)

-            # Create batch inputs
-            prompts = ["<image> ocr"] * len(images)
-            model_inputs = processor(text=prompts, images=images, return_tensors="pt")

-            # Move to appropriate device
-            model_inputs = {k: v.to(model.device) for k, v in model_inputs.items()}

-            # Generate with memory optimization
-            with torch.inference_mode():
-                generations = model.generate(**model_inputs, max_new_tokens=200, do_sample=False)

-            # Decode outputs
-            batch_results = [processor.decode(generations[i], skip_special_tokens=True) for i in range(len(images))]
-            all_results.extend(batch_results)

-            # Clean up batch resources
-            del model_inputs, generations, images
-            clean_memory()

-        # Schedule cleanup after response
-        background_tasks.add_task(clean_memory)

-        logger.info(f"Batch processing completed in {time.time() - start_time:.2f} seconds")
-        return {"extracted_texts": all_results}
-    except Exception as e:
-        logger.error(f"Error in batch processing: {str(e)}")
-        return {"error": str(e)}
+# @app.post("/batch_extract_text")
+# async def batch_extract_text(batch_size:int, background_tasks: BackgroundTasks, files: List[UploadFile] = File(...)):
+#     """Extract text from multiple images with batching"""
+#     try:
+#         start_time = time.time()

+#         # Limit batch size for memory management
+#         max_batch_size = 32 # Adjust based on your GPU memory

+#         # if len(files) > 32:
+#         #     return {"error": "A maximum of 20 images can be processed at a time."}

+#         load_model()  # Ensure model is loaded

+#         all_results = []

+#         # Process in smaller batches
+#         for i in range(0, len(files), max_batch_size):
+#             batch_files = files[i:i+max_batch_size]

+#             # Load images
+#             images = []
+#             for file in batch_files:
+#                 image_data = await file.read()
+#                 img = Image.open(BytesIO(image_data)).convert("RGB")
+#                 images.append(img)

+#             # Create batch inputs
+#             prompts = ["<image> ocr"] * len(images)
+#             model_inputs = processor(text=prompts, images=images, return_tensors="pt")

+#             # Move to appropriate device
+#             model_inputs = {k: v.to(model.device) for k, v in model_inputs.items()}

+#             # Generate with memory optimization
+#             with torch.inference_mode():
+#                 generations = model.generate(**model_inputs, max_new_tokens=200, do_sample=False)

+#             # Decode outputs
+#             batch_results = [processor.decode(generations[i], skip_special_tokens=True) for i in range(len(images))]
+#             all_results.extend(batch_results)

+#             # Clean up batch resources
+#             del model_inputs, generations, images
+#             clean_memory()

+#         # Schedule cleanup after response
+#         background_tasks.add_task(clean_memory)

+#         logger.info(f"Batch processing completed in {time.time() - start_time:.2f} seconds")
+#         return {"extracted_texts": all_results}
+#     except Exception as e:
+#         logger.error(f"Error in batch processing: {str(e)}")
+#         return {"error": str(e)}


 # Health check endpoint
-@app.get("/health")
-async def health_check():
-    # Generate a random image (20x40 pixels) with random RGB values
-    random_data = np.random.randint(0, 256, (20, 40, 3), dtype=np.uint8)

-    # Create an image from the random data
-    image = Image.fromarray(random_data)
-    predict(image)
-    clean_memory()
-    return {"status": "healthy"}
+# @app.get("/health")
+# async def health_check():
+#     # Generate a random image (20x40 pixels) with random RGB values
+#     random_data = np.random.randint(0, 256, (20, 40, 3), dtype=np.uint8)

+#     # Create an image from the random data
+#     image = Image.fromarray(random_data)
+#     predict(image)
+#     clean_memory()
+#     return {"status": "healthy"}

 # if __name__ == "__main__":
 #     import uvicorn
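For reference, a minimal client sketch for the endpoint that survives this refactor, /batch_extract_text_vllm. This is an illustration only: the host and port (7860) are taken from the commented-out uvicorn.run call above, the requests library and the image paths are assumptions, and the response keys ("extracted_texts" on success, "error" on failure) come from the handler shown in the diff.

import requests

# Hypothetical deployment address and input files; adjust to your setup.
url = "http://localhost:7860/batch_extract_text_vllm"
paths = ["page1.png", "page2.png"]

# The endpoint expects a multipart form with one or more parts named "files".
files = [("files", (p, open(p, "rb"), "image/png")) for p in paths]
response = requests.post(url, files=files)
response.raise_for_status()

# Success: {"extracted_texts": [...]}; failure: {"error": "..."}
print(response.json())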
requirements.txt CHANGED

@@ -3,7 +3,6 @@ uvicorn
 numpy
 huggingface_hub
 python-dotenv
-transformers
 torch
 accelerate
 pillow
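requirements.txt keeps uvicorn, so the refactored app can still be launched the way the commented-out block at the bottom of main.py suggests. A minimal launch sketch, assuming the module is named main and that vllm itself is installed (it is imported by the new main.py but does not appear in the requirement lines shown here):

# launch.py - hypothetical launcher mirroring the commented-out uvicorn block in main.py
import uvicorn

if __name__ == "__main__":
    uvicorn.run("main:app", host="0.0.0.0", port=7860)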