# ocr_test_pali / main.py
import time
from io import BytesIO
import os
from dotenv import load_dotenv
from PIL import Image
import logging
from typing import List
from huggingface_hub import login
from fastapi import FastAPI, File, UploadFile
from vllm import LLM, SamplingParams
import torch
import torch._dynamo
torch._dynamo.config.suppress_errors = True
# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
# Load environment variables
load_dotenv()
# Set the cache directory to a writable path
os.environ["TORCHINDUCTOR_CACHE_DIR"] = "/tmp/torch_inductor_cache"
token = os.getenv("huggingface_ankit")
# Login to the Hugging Face Hub
login(token)
app = FastAPI()
llm = None
def load_vllm_model():
    """Load the PaliGemma 2 mix model into vLLM once and reuse it for all requests."""
    global llm
    if llm is None:
        logger.info("Loading vLLM model...")
        llm = LLM(
            model="google/paligemma2-3b-mix-448",
            trust_remote_code=True,
            max_model_len=4096,
            dtype="float16",
        )
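# Optional warm-up: loading the model at application startup instead of on the first
# request avoids a slow first call. This hook is not part of the original file; it is
# a sketch assuming FastAPI's startup event fits this deployment.
# @app.on_event("startup")
# def _warm_up_model():
#     load_vllm_model()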
@app.post("/batch_extract_text_vllm")
async def batch_extract_text_vllm(files: List[UploadFile] = File(...)):
try:
start_time = time.time()
load_vllm_model()
results = []
sampling_params = SamplingParams(temperature=0.0,max_tokens=32)
# Load images
images = []
for file in files:
image_data = await file.read()
img = Image.open(BytesIO(image_data)).convert("RGB")
images.append(img)
for image in images:
inputs = {
"prompt": "ocr",
"multi_modal_data": {
"image": image
},
}
outputs = llm.generate(inputs, sampling_params)
for o in outputs:
generated_text = o.outputs[0].text
results.append(generated_text)
logger.info(f"vLLM Batch processing completed in {time.time() - start_time:.2f} seconds")
return {"extracted_texts": results}
except Exception as e:
logger.error(f"Error in batch processing vLLM: {str(e)}")
return {"error": str(e)}
# # main.py
# from fastapi import FastAPI, File, UploadFile
# from transformers import PaliGemmaProcessor, PaliGemmaForConditionalGeneration
# from transformers.image_utils import load_image
# import torch
# from io import BytesIO
# import os
# from dotenv import load_dotenv
# from PIL import Image
# from huggingface_hub import login
# # Load environment variables
# load_dotenv()
# # Set the cache directory to a writable path
# os.environ["TORCHINDUCTOR_CACHE_DIR"] = "/tmp/torch_inductor_cache"
# token = os.getenv("huggingface_ankit")
# # Login to the Hugging Face Hub
# login(token)
# app = FastAPI()
# model_id = "google/paligemma2-3b-mix-448"
# model = PaliGemmaForConditionalGeneration.from_pretrained(model_id).to('cuda')
# processor = PaliGemmaProcessor.from_pretrained(model_id)
# def predict(image):
#     prompt = "<image> ocr"
#     model_inputs = processor(text=prompt, images=image, return_tensors="pt").to('cuda')
#     input_len = model_inputs["input_ids"].shape[-1]
#     with torch.inference_mode():
#         generation = model.generate(**model_inputs, max_new_tokens=200)
#     torch.cuda.empty_cache()
#     decoded = processor.decode(generation[0], skip_special_tokens=True)  # [len(prompt):].lstrip("\n")
#     return decoded
#
# @app.post("/extract_text")
# async def extract_text(file: UploadFile = File(...)):
#     image = Image.open(BytesIO(await file.read())).convert("RGB")  # Ensure it's a valid PIL image
#     text = predict(image)
#     return {"extracted_text": text}
#
# @app.post("/batch_extract_text")
# async def batch_extract_text(files: list[UploadFile] = File(...)):
#     # if len(files) > 20:
#     #     return {"error": "A maximum of 20 images can be processed at a time."}
#     images = [Image.open(BytesIO(await file.read())).convert("RGB") for file in files]
#     prompts = ["OCR"] * len(images)
#     model_inputs = processor(text=prompts, images=images, return_tensors="pt").to(torch.bfloat16).to(model.device)
#     input_len = model_inputs["input_ids"].shape[-1]
#     with torch.inference_mode():
#         generations = model.generate(**model_inputs, max_new_tokens=200, do_sample=False)
#     torch.cuda.empty_cache()
#     extracted_texts = [processor.decode(generations[i], skip_special_tokens=True) for i in range(len(images))]
#     return {"extracted_texts": extracted_texts}
#
# if __name__ == "__main__":
#     import uvicorn
#     uvicorn.run(app, host="0.0.0.0", port=7860)
# Global variables for model and processor
# model = None
# processor = None
#
# def load_model():
#     """Load model and processor when needed"""
#     global model, processor
#     if model is None:
#         model_id = "google/paligemma2-3b-mix-448"
#         logger.info(f"Loading model {model_id}")
#         # Load model with memory-efficient settings
#         model = PaliGemmaForConditionalGeneration.from_pretrained(
#             model_id,
#             device_map="auto",
#             torch_dtype=torch.bfloat16  # Use lower precision for memory efficiency
#         )
#         processor = PaliGemmaProcessor.from_pretrained(model_id)
#         logger.info("Model loaded successfully")
#
# def clean_memory():
#     """Force garbage collection and clear CUDA cache"""
#     gc.collect()
#     if torch.cuda.is_available():
#         # Clear GPU cache
#         torch.cuda.empty_cache()
#         logger.info(f"Memory allocated after clearing cache: {torch.cuda.memory_allocated()} bytes")
#     logger.info("Memory cleaned")
# def predict(image):
#     """Process a single image"""
#     load_model()  # Ensure model is loaded
#     # Process input
#     prompt = "<image> ocr"
#     model_inputs = processor(text=prompt, images=image, return_tensors="pt")
#     # Move to appropriate device
#     model_inputs = {k: v.to(model.device) for k, v in model_inputs.items()}
#     # Generate with memory optimization
#     with torch.inference_mode():
#         generation = model.generate(**model_inputs, max_new_tokens=200)
#     # Decode output
#     decoded = processor.decode(generation[0], skip_special_tokens=True)
#     # Clean up intermediates
#     del model_inputs, generation
#     clean_memory()
#     # del model, processor
#     return decoded
#
# @app.post("/extract_text")
# async def extract_text(background_tasks: BackgroundTasks, file: UploadFile = File(...)):
#     """Extract text from a single image"""
#     try:
#         start_time = time.time()
#         image = Image.open(BytesIO(await file.read())).convert("RGB")
#         text = predict(image)
#         # Schedule cleanup after response
#         background_tasks.add_task(clean_memory)
#         logger.info(f"Processing completed in {time.time() - start_time:.2f} seconds")
#         return {"extracted_text": text}
#     except Exception as e:
#         logger.error(f"Error processing image: {str(e)}")
#         return {"error": str(e)}
# @app.post("/batch_extract_text")
# async def batch_extract_text(batch_size:int, background_tasks: BackgroundTasks, files: List[UploadFile] = File(...)):
# """Extract text from multiple images with batching"""
# try:
# start_time = time.time()
# # Limit batch size for memory management
# max_batch_size = 32 # Adjust based on your GPU memory
# # if len(files) > 32:
# # return {"error": "A maximum of 20 images can be processed at a time."}
# load_model() # Ensure model is loaded
# all_results = []
# # Process in smaller batches
# for i in range(0, len(files), max_batch_size):
# batch_files = files[i:i+max_batch_size]
# # Load images
# images = []
# for file in batch_files:
# image_data = await file.read()
# img = Image.open(BytesIO(image_data)).convert("RGB")
# images.append(img)
# # Create batch inputs
# prompts = ["<image> ocr"] * len(images)
# model_inputs = processor(text=prompts, images=images, return_tensors="pt")
# # Move to appropriate device
# model_inputs = {k: v.to(model.device) for k, v in model_inputs.items()}
# # Generate with memory optimization
# with torch.inference_mode():
# generations = model.generate(**model_inputs, max_new_tokens=200, do_sample=False)
# # Decode outputs
# batch_results = [processor.decode(generations[i], skip_special_tokens=True) for i in range(len(images))]
# all_results.extend(batch_results)
# # Clean up batch resources
# del model_inputs, generations, images
# clean_memory()
# # Schedule cleanup after response
# background_tasks.add_task(clean_memory)
# logger.info(f"Batch processing completed in {time.time() - start_time:.2f} seconds")
# return {"extracted_texts": all_results}
# except Exception as e:
# logger.error(f"Error in batch processing: {str(e)}")
# return {"error": str(e)}
# Health check endpoint
# @app.get("/health")
# async def health_check():
#     # Generate a random image (20x40 pixels) with random RGB values
#     random_data = np.random.randint(0, 256, (20, 40, 3), dtype=np.uint8)
#     # Create an image from the random data
#     image = Image.fromarray(random_data)
#     predict(image)
#     clean_memory()
#     return {"status": "healthy"}
#
# if __name__ == "__main__":
#     import uvicorn
#     # Start the server with proper worker configuration
#     uvicorn.run(
#         app,
#         host="0.0.0.0",
#         port=7860,
#         log_level="info",
#         workers=1  # Multiple workers can cause GPU memory issues
#     )
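
# Local entrypoint: a minimal sketch, assuming the same host/port and single-worker
# setup used by the commented-out older versions above (multiple workers would each
# load their own copy of the model and can exhaust GPU memory).
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860, log_level="info", workers=1)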