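"""Medicine Name Extractor: a Gradio Space that uses the Qwen2-VL-OCR model to read medicine names from prescription images (CPU-only inference)."""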
import gradio as gr
import torch
from transformers import Qwen2VLForConditionalGeneration, AutoProcessor
from qwen_vl_utils import process_vision_info
# Load the model on CPU
def load_model():
    model = Qwen2VLForConditionalGeneration.from_pretrained(
        "prithivMLmods/Qwen2-VL-OCR-2B-Instruct",
        torch_dtype=torch.float32,
        device_map="cpu"
    )
    processor = AutoProcessor.from_pretrained("prithivMLmods/Qwen2-VL-OCR-2B-Instruct")
    return model, processor
# Extract medicine names (standalone version that reloads the model on every call)
def extract_medicine_names(image):
    model, processor = load_model()
    # Prepare the message with the specific prompt for medicine extraction
    messages = [
        {
            "role": "user",
            "content": [
                {
                    "type": "image",
                    "image": image,
                },
                {"type": "text", "text": "Extract and list ONLY the names of medicines/drugs from this prescription image. Output the medicine names as a numbered list without any additional information or descriptions."},
            ],
        }
    ]
    # Prepare for inference
    text = processor.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    image_inputs, video_inputs = process_vision_info(messages)
    inputs = processor(
        text=[text],
        images=image_inputs,
        videos=video_inputs,
        padding=True,
        return_tensors="pt",
    )
    # Generate output
    generated_ids = model.generate(**inputs, max_new_tokens=256)
    # Trim the prompt tokens so only the newly generated text is decoded
    generated_ids_trimmed = [
        out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
    ]
    output_text = processor.batch_decode(
        generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
    )[0]
    # Remove <|im_end|> and any other special tokens that might appear in the output
    output_text = output_text.replace("<|im_end|>", "").strip()
    return output_text
# Create a singleton model and processor to avoid reloading for each request
model_instance = None
processor_instance = None
def get_model_and_processor():
    global model_instance, processor_instance
    if model_instance is None or processor_instance is None:
        model_instance, processor_instance = load_model()
    return model_instance, processor_instance
# Optimized extraction function that uses the singleton model
def extract_medicine_names_optimized(image):
    if image is None:
        return "Please upload an image."
    model, processor = get_model_and_processor()
    # Prepare the message with the specific prompt for medicine extraction
    messages = [
        {
            "role": "user",
            "content": [
                {
                    "type": "image",
                    "image": image,
                },
                {"type": "text", "text": "Extract and list ONLY the names of medicines/drugs from this prescription image. Output the medicine names as a numbered list without any additional information or descriptions."},
            ],
        }
    ]
    # Prepare for inference
    text = processor.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    image_inputs, video_inputs = process_vision_info(messages)
    inputs = processor(
        text=[text],
        images=image_inputs,
        videos=video_inputs,
        padding=True,
        return_tensors="pt",
    )
    # Generate output
    generated_ids = model.generate(**inputs, max_new_tokens=256)
    # Trim the prompt tokens so only the newly generated text is decoded
    generated_ids_trimmed = [
        out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
    ]
    output_text = processor.batch_decode(
        generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
    )[0]
    # Remove <|im_end|> and any other special tokens that might appear in the output
    output_text = output_text.replace("<|im_end|>", "").strip()
    return output_text
# Create Gradio interface
with gr.Blocks(title="Medicine Name Extractor") as app:
gr.Markdown("# Medicine Name Extractor")
gr.Markdown("Upload a medical prescription image to extract the names of medicines.")
with gr.Row():
with gr.Column():
input_image = gr.Image(type="pil", label="Upload Prescription Image")
extract_btn = gr.Button("Extract Medicine Names", variant="primary")
with gr.Column():
output_text = gr.Textbox(label="Extracted Medicine Names", lines=10)
extract_btn.click(
fn=extract_medicine_names_optimized,
inputs=input_image,
outputs=output_text
)
gr.Markdown("### Notes")
gr.Markdown("- This tool uses the Qwen2-VL-OCR model to extract text from prescription images")
gr.Markdown("- For best results, ensure the prescription image is clear and readable")
gr.Markdown("- Processing may take some time as the model runs on CPU")
# Launch the app
if __name__ == "__main__":
    app.launch()