import time
from io import BytesIO
import os
from dotenv import load_dotenv
from PIL import Image
import logging
from typing import List
from huggingface_hub import login
from fastapi import FastAPI, File, UploadFile
from vllm import LLM, SamplingParams
import torch
import torch._dynamo

torch._dynamo.config.suppress_errors = True

# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# Load environment variables
load_dotenv()

# Set the cache directory to a writable path
os.environ["TORCHINDUCTOR_CACHE_DIR"] = "/tmp/torch_inductor_cache"

token = os.getenv("huggingface_ankit")
# Login to the Hugging Face Hub
login(token)

app = FastAPI()

llm = None
def load_vllm_model():
    """Lazily initialize the global vLLM engine on first use."""
    global llm
    if llm is None:
        logger.info("Loading vLLM model...")
        llm = LLM(
            model="google/paligemma2-3b-mix-448",
            trust_remote_code=True,
            max_model_len=4096,
            dtype="float16",
        )
# NOTE: the route path below is assumed; no decorator was present on this handler in the snippet.
@app.post("/batch_extract_text")
async def batch_extract_text_vllm(files: List[UploadFile] = File(...)):
    """Extract text from multiple uploaded images using the vLLM engine."""
    try:
        start_time = time.time()
        load_vllm_model()
        results = []
        sampling_params = SamplingParams(temperature=0.0, max_tokens=32)

        # Load images
        images = []
        for file in files:
            image_data = await file.read()
            img = Image.open(BytesIO(image_data)).convert("RGB")
            images.append(img)

        # Run OCR on each image (one generate call per image)
        for image in images:
            inputs = {
                "prompt": "ocr",
                "multi_modal_data": {
                    "image": image
                },
            }
            outputs = llm.generate(inputs, sampling_params)
            for o in outputs:
                generated_text = o.outputs[0].text
                results.append(generated_text)

        logger.info(f"vLLM Batch processing completed in {time.time() - start_time:.2f} seconds")
        return {"extracted_texts": results}
    except Exception as e:
        logger.error(f"Error in batch processing vLLM: {str(e)}")
        return {"error": str(e)}
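
# Minimal local-run sketch (assumption): the active code above defines no server
# entrypoint, so this mirrors the uvicorn setup used in the commented-out legacy
# versions below. The host/port (0.0.0.0:7860, the usual Hugging Face Spaces port)
# and the single-worker choice are assumptions, not confirmed by the active code.
#
# Example client call (assumed route path, adjust to the real one):
#   curl -X POST "http://localhost:7860/batch_extract_text" \
#        -F "files=@page1.png" -F "files=@page2.png"
if __name__ == "__main__":
    import uvicorn

    # A single worker keeps only one copy of the vLLM engine in GPU memory.
    uvicorn.run(app, host="0.0.0.0", port=7860, log_level="info", workers=1)
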
# # main.py
# from fastapi import FastAPI, File, UploadFile
# from transformers import PaliGemmaProcessor, PaliGemmaForConditionalGeneration
# from transformers.image_utils import load_image
# import torch
# from io import BytesIO
# import os
# from dotenv import load_dotenv
# from PIL import Image
# from huggingface_hub import login
#
# # Load environment variables
# load_dotenv()
#
# # Set the cache directory to a writable path
# os.environ["TORCHINDUCTOR_CACHE_DIR"] = "/tmp/torch_inductor_cache"
#
# token = os.getenv("huggingface_ankit")
# # Login to the Hugging Face Hub
# login(token)
#
# app = FastAPI()
#
# model_id = "google/paligemma2-3b-mix-448"
# model = PaliGemmaForConditionalGeneration.from_pretrained(model_id).to('cuda')
# processor = PaliGemmaProcessor.from_pretrained(model_id)
#
# def predict(image):
#     prompt = "<image> ocr"
#     model_inputs = processor(text=prompt, images=image, return_tensors="pt").to('cuda')
#     input_len = model_inputs["input_ids"].shape[-1]
#     with torch.inference_mode():
#         generation = model.generate(**model_inputs, max_new_tokens=200)
#         torch.cuda.empty_cache()
#     decoded = processor.decode(generation[0], skip_special_tokens=True)  # [len(prompt):].lstrip("\n")
#     return decoded
#
# @app.post("/extract_text")
# async def extract_text(file: UploadFile = File(...)):
#     image = Image.open(BytesIO(await file.read())).convert("RGB")  # Ensure it's a valid PIL image
#     text = predict(image)
#     return {"extracted_text": text}
#
# @app.post("/batch_extract_text")
# async def batch_extract_text(files: list[UploadFile] = File(...)):
#     # if len(files) > 20:
#     #     return {"error": "A maximum of 20 images can be processed at a time."}
#     images = [Image.open(BytesIO(await file.read())).convert("RGB") for file in files]
#     prompts = ["OCR"] * len(images)
#     model_inputs = processor(text=prompts, images=images, return_tensors="pt").to(torch.bfloat16).to(model.device)
#     input_len = model_inputs["input_ids"].shape[-1]
#     with torch.inference_mode():
#         generations = model.generate(**model_inputs, max_new_tokens=200, do_sample=False)
#         torch.cuda.empty_cache()
#     extracted_texts = [processor.decode(generations[i], skip_special_tokens=True) for i in range(len(images))]
#     return {"extracted_texts": extracted_texts}
#
# if __name__ == "__main__":
#     import uvicorn
#     uvicorn.run(app, host="0.0.0.0", port=7860)
# Global variables for model and processor
# model = None
# processor = None
#
# def load_model():
#     """Load model and processor when needed"""
#     global model, processor
#     if model is None:
#         model_id = "google/paligemma2-3b-mix-448"
#         logger.info(f"Loading model {model_id}")
#         # Load model with memory-efficient settings
#         model = PaliGemmaForConditionalGeneration.from_pretrained(
#             model_id,
#             device_map="auto",
#             torch_dtype=torch.bfloat16  # Use lower precision for memory efficiency
#         )
#         processor = PaliGemmaProcessor.from_pretrained(model_id)
#         logger.info("Model loaded successfully")
#
# def clean_memory():
#     """Force garbage collection and clear CUDA cache"""
#     gc.collect()
#     if torch.cuda.is_available():
#         torch.cuda.empty_cache()
#         # Clear GPU cache
#         torch.cuda.empty_cache()
#         logger.info(f"Memory allocated after clearing cache: {torch.cuda.memory_allocated()} bytes")
#     logger.info("Memory cleaned")
#
# def predict(image):
#     """Process a single image"""
#     load_model()  # Ensure model is loaded
#     # Process input
#     prompt = "<image> ocr"
#     model_inputs = processor(text=prompt, images=image, return_tensors="pt")
#     # Move to appropriate device
#     model_inputs = {k: v.to(model.device) for k, v in model_inputs.items()}
#     # Generate with memory optimization
#     with torch.inference_mode():
#         generation = model.generate(**model_inputs, max_new_tokens=200)
#     # Decode output
#     decoded = processor.decode(generation[0], skip_special_tokens=True)
#     # Clean up intermediates
#     del model_inputs, generation
#     clean_memory()
#     # del model, processor
#     return decoded
#
# @app.post("/extract_text")
# async def extract_text(background_tasks: BackgroundTasks, file: UploadFile = File(...)):
#     """Extract text from a single image"""
#     try:
#         start_time = time.time()
#         image = Image.open(BytesIO(await file.read())).convert("RGB")
#         text = predict(image)
#         # Schedule cleanup after response
#         background_tasks.add_task(clean_memory)
#         logger.info(f"Processing completed in {time.time() - start_time:.2f} seconds")
#         return {"extracted_text": text}
#     except Exception as e:
#         logger.error(f"Error processing image: {str(e)}")
#         return {"error": str(e)}
#
# @app.post("/batch_extract_text")
# async def batch_extract_text(batch_size: int, background_tasks: BackgroundTasks, files: List[UploadFile] = File(...)):
#     """Extract text from multiple images with batching"""
#     try:
#         start_time = time.time()
#         # Limit batch size for memory management
#         max_batch_size = 32  # Adjust based on your GPU memory
#         # if len(files) > 32:
#         #     return {"error": "A maximum of 20 images can be processed at a time."}
#         load_model()  # Ensure model is loaded
#         all_results = []
#         # Process in smaller batches
#         for i in range(0, len(files), max_batch_size):
#             batch_files = files[i:i+max_batch_size]
#             # Load images
#             images = []
#             for file in batch_files:
#                 image_data = await file.read()
#                 img = Image.open(BytesIO(image_data)).convert("RGB")
#                 images.append(img)
#             # Create batch inputs
#             prompts = ["<image> ocr"] * len(images)
#             model_inputs = processor(text=prompts, images=images, return_tensors="pt")
#             # Move to appropriate device
#             model_inputs = {k: v.to(model.device) for k, v in model_inputs.items()}
#             # Generate with memory optimization
#             with torch.inference_mode():
#                 generations = model.generate(**model_inputs, max_new_tokens=200, do_sample=False)
#             # Decode outputs
#             batch_results = [processor.decode(generations[i], skip_special_tokens=True) for i in range(len(images))]
#             all_results.extend(batch_results)
#             # Clean up batch resources
#             del model_inputs, generations, images
#             clean_memory()
#         # Schedule cleanup after response
#         background_tasks.add_task(clean_memory)
#         logger.info(f"Batch processing completed in {time.time() - start_time:.2f} seconds")
#         return {"extracted_texts": all_results}
#     except Exception as e:
#         logger.error(f"Error in batch processing: {str(e)}")
#         return {"error": str(e)}
#
# # Health check endpoint
# @app.get("/health")
# async def health_check():
#     # Generate a random image (20x40 pixels) with random RGB values
#     random_data = np.random.randint(0, 256, (20, 40, 3), dtype=np.uint8)
#     # Create an image from the random data
#     image = Image.fromarray(random_data)
#     predict(image)
#     clean_memory()
#     return {"status": "healthy"}
#
# if __name__ == "__main__":
#     import uvicorn
#     # Start the server with proper worker configuration
#     uvicorn.run(
#         app,
#         host="0.0.0.0",
#         port=7860,
#         log_level="info",
#         workers=1  # Multiple workers can cause GPU memory issues
#     )