import time
from io import BytesIO
import os
from dotenv import load_dotenv
from PIL import Image
import logging
from typing import List
from huggingface_hub import login
from fastapi import FastAPI, File, UploadFile
from vllm import LLM, SamplingParams
import torch
import torch._dynamo

# Suppress torch.compile/dynamo errors instead of failing hard
torch._dynamo.config.suppress_errors = True

# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# Load environment variables
load_dotenv()

# Set the cache directory to a writable path
os.environ["TORCHINDUCTOR_CACHE_DIR"] = "/tmp/torch_inductor_cache"

token = os.getenv("huggingface_ankit")
# Login to the Hugging Face Hub
login(token)

app = FastAPI()

llm = None


def load_vllm_model():
    """Lazily create the vLLM engine on first use."""
    global llm
    if llm is None:
        logger.info("Loading vLLM model...")
        llm = LLM(
            model="google/paligemma2-3b-mix-448",
            trust_remote_code=True,
            max_model_len=4096,
            dtype="float16",
        )


@app.post("/batch_extract_text_vllm")
async def batch_extract_text_vllm(files: List[UploadFile] = File(...)):
    try:
        start_time = time.time()
        load_vllm_model()
        results = []
        # Greedy decoding; OCR output is capped at 32 tokens
        sampling_params = SamplingParams(temperature=0.0, max_tokens=32)

        # Load images
        images = []
        for file in files:
            image_data = await file.read()
            img = Image.open(BytesIO(image_data)).convert("RGB")
            images.append(img)

        # Run OCR on each image with the PaliGemma "ocr" task prompt
        for image in images:
            inputs = {
                "prompt": "ocr",
                "multi_modal_data": {"image": image},
            }
            outputs = llm.generate(inputs, sampling_params)
            for o in outputs:
                generated_text = o.outputs[0].text
                results.append(generated_text)

        logger.info(f"vLLM Batch processing completed in {time.time() - start_time:.2f} seconds")
        return {"extracted_texts": results}
    except Exception as e:
        logger.error(f"Error in batch processing vLLM: {str(e)}")
        return {"error": str(e)}
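# Example client for the endpoint above (a sketch, not part of the service).
# The endpoint expects one or more multipart form fields named "files"; the URL
# and file names below are placeholders.
#
# import requests
#
# with open("page1.png", "rb") as f1, open("page2.png", "rb") as f2:
#     response = requests.post(
#         "http://localhost:7860/batch_extract_text_vllm",
#         files=[
#             ("files", ("page1.png", f1, "image/png")),
#             ("files", ("page2.png", f2, "image/png")),
#         ],
#     )
#     print(response.json())  # {"extracted_texts": ["...", "..."]}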
# # main.py
# from fastapi import FastAPI, File, UploadFile
# from transformers import PaliGemmaProcessor, PaliGemmaForConditionalGeneration
# from transformers.image_utils import load_image
# import torch
# from io import BytesIO
# import os
# from dotenv import load_dotenv
# from PIL import Image
# from huggingface_hub import login

# # Load environment variables
# load_dotenv()

# # Set the cache directory to a writable path
# os.environ["TORCHINDUCTOR_CACHE_DIR"] = "/tmp/torch_inductor_cache"

# token = os.getenv("huggingface_ankit")
# # Login to the Hugging Face Hub
# login(token)

# app = FastAPI()

# model_id = "google/paligemma2-3b-mix-448"
# model = PaliGemmaForConditionalGeneration.from_pretrained(model_id).to('cuda')
# processor = PaliGemmaProcessor.from_pretrained(model_id)

# def predict(image):
#     prompt = " ocr"
#     model_inputs = processor(text=prompt, images=image, return_tensors="pt").to('cuda')
#     input_len = model_inputs["input_ids"].shape[-1]
#     with torch.inference_mode():
#         generation = model.generate(**model_inputs, max_new_tokens=200)
#         torch.cuda.empty_cache()
#     decoded = processor.decode(generation[0], skip_special_tokens=True)  # [len(prompt):].lstrip("\n")
#     return decoded

# @app.post("/extract_text")
# async def extract_text(file: UploadFile = File(...)):
#     image = Image.open(BytesIO(await file.read())).convert("RGB")  # Ensure it's a valid PIL image
#     text = predict(image)
#     return {"extracted_text": text}

# @app.post("/batch_extract_text")
# async def batch_extract_text(files: list[UploadFile] = File(...)):
#     # if len(files) > 20:
#     #     return {"error": "A maximum of 20 images can be processed at a time."}
#     images = [Image.open(BytesIO(await file.read())).convert("RGB") for file in files]
#     prompts = ["OCR"] * len(images)
#     model_inputs = processor(text=prompts, images=images, return_tensors="pt").to(torch.bfloat16).to(model.device)
#     input_len = model_inputs["input_ids"].shape[-1]
#     with torch.inference_mode():
#         generations = model.generate(**model_inputs, max_new_tokens=200, do_sample=False)
#         torch.cuda.empty_cache()
#     extracted_texts = [processor.decode(generations[i], skip_special_tokens=True) for i in range(len(images))]
#     return {"extracted_texts": extracted_texts}

# if __name__ == "__main__":
#     import uvicorn
#     uvicorn.run(app, host="0.0.0.0", port=7860)

# Global variables for model and processor
# model = None
# processor = None

# def load_model():
#     """Load model and processor when needed"""
#     global model, processor
#     if model is None:
#         model_id = "google/paligemma2-3b-mix-448"
#         logger.info(f"Loading model {model_id}")
#         # Load model with memory-efficient settings
#         model = PaliGemmaForConditionalGeneration.from_pretrained(
#             model_id,
#             device_map="auto",
#             torch_dtype=torch.bfloat16  # Use lower precision for memory efficiency
#         )
#         processor = PaliGemmaProcessor.from_pretrained(model_id)
#         logger.info("Model loaded successfully")

# def clean_memory():
#     """Force garbage collection and clear CUDA cache"""
#     gc.collect()
#     if torch.cuda.is_available():
#         torch.cuda.empty_cache()
#         # Clear GPU cache
#         torch.cuda.empty_cache()
#         logger.info(f"Memory allocated after clearing cache: {torch.cuda.memory_allocated()} bytes")
#     logger.info("Memory cleaned")

# def predict(image):
#     """Process a single image"""
#     load_model()  # Ensure model is loaded
#     # Process input
#     prompt = " ocr"
#     model_inputs = processor(text=prompt, images=image, return_tensors="pt")
#     # Move to appropriate device
#     model_inputs = {k: v.to(model.device) for k, v in model_inputs.items()}
#     # Generate with memory optimization
#     with torch.inference_mode():
#         generation = model.generate(**model_inputs, max_new_tokens=200)
#     # Decode output
#     decoded = processor.decode(generation[0], skip_special_tokens=True)
#     # Clean up intermediates
#     del model_inputs, generation
#     clean_memory()
#     # del model, processor
#     return decoded

# @app.post("/extract_text")
# async def extract_text(background_tasks: BackgroundTasks, file: UploadFile = File(...)):
#     """Extract text from a single image"""
#     try:
#         start_time = time.time()
#         image = Image.open(BytesIO(await file.read())).convert("RGB")
#         text = predict(image)
#         # Schedule cleanup after response
#         background_tasks.add_task(clean_memory)
#         logger.info(f"Processing completed in {time.time() - start_time:.2f} seconds")
#         return {"extracted_text": text}
#     except Exception as e:
#         logger.error(f"Error processing image: {str(e)}")
#         return {"error": str(e)}

# @app.post("/batch_extract_text")
# async def batch_extract_text(batch_size: int, background_tasks: BackgroundTasks, files: List[UploadFile] = File(...)):
#     """Extract text from multiple images with batching"""
#     try:
#         start_time = time.time()
#         # Limit batch size for memory management
#         max_batch_size = 32  # Adjust based on your GPU memory
#         # if len(files) > 32:
#         #     return {"error": "A maximum of 20 images can be processed at a time."}
#         load_model()  # Ensure model is loaded
#         all_results = []
#         # Process in smaller batches
#         for i in range(0, len(files), max_batch_size):
#             batch_files = files[i:i+max_batch_size]
#             # Load images
#             images = []
#             for file in batch_files:
#                 image_data = await file.read()
#                 img = Image.open(BytesIO(image_data)).convert("RGB")
#                 images.append(img)
#             # Create batch inputs
#             prompts = [" ocr"] * len(images)
#             model_inputs = processor(text=prompts, images=images, return_tensors="pt")
#             # Move to appropriate device
#             model_inputs = {k: v.to(model.device) for k, v in model_inputs.items()}
#             # Generate with memory optimization
#             with torch.inference_mode():
#                 generations = model.generate(**model_inputs, max_new_tokens=200, do_sample=False)
#             # Decode outputs
#             batch_results = [processor.decode(generations[i], skip_special_tokens=True) for i in range(len(images))]
#             all_results.extend(batch_results)
#             # Clean up batch resources
#             del model_inputs, generations, images
#             clean_memory()
#         # Schedule cleanup after response
#         background_tasks.add_task(clean_memory)
#         logger.info(f"Batch processing completed in {time.time() - start_time:.2f} seconds")
#         return {"extracted_texts": all_results}
#     except Exception as e:
#         logger.error(f"Error in batch processing: {str(e)}")
#         return {"error": str(e)}

# Health check endpoint
# @app.get("/health")
# async def health_check():
#     # Generate a random image (20x40 pixels) with random RGB values
#     random_data = np.random.randint(0, 256, (20, 40, 3), dtype=np.uint8)
#     # Create an image from the random data
#     image = Image.fromarray(random_data)
#     predict(image)
#     clean_memory()
#     return {"status": "healthy"}

# if __name__ == "__main__":
#     import uvicorn
#     # Start the server with proper worker configuration
#     uvicorn.run(
#         app,
#         host="0.0.0.0",
#         port=7860,
#         log_level="info",
#         workers=1  # Multiple workers can cause GPU memory issues
#     )
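# Entrypoint for local runs of the active vLLM app above; a minimal sketch that
# mirrors the commented-out legacy entrypoints (assumes uvicorn is installed and
# port 7860 is free).
if __name__ == "__main__":
    import uvicorn

    # A single worker keeps only one copy of the vLLM engine in GPU memory
    uvicorn.run(app, host="0.0.0.0", port=7860, log_level="info", workers=1)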