File size: 5,109 Bytes
5dbe551
97b296f
d352924
 
 
5dbe551
d352924
 
 
 
 
 
 
 
 
ade4954
d352924
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c7cc8ee
d352924
ade4954
d352924
ade4954
d352924
 
 
 
 
 
 
 
 
 
 
 
 
beecb06
d352924
 
 
 
 
ade4954
d352924
 
 
beecb06
d352924
 
ade4954
d352924
 
 
 
 
 
 
 
ade4954
d352924
 
ade4954
d352924
5dbe551
d352924
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5dbe551
d352924
c7cc8ee
d352924
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
import gradio as gr
import torch
from transformers import Qwen2VLForConditionalGeneration, AutoProcessor
from qwen_vl_utils import process_vision_info
import re

# Load the model on CPU
def load_model(model_name="prithivMLmods/Qwen2-VL-OCR-2B-Instruct"):
    """Load the Qwen2-VL OCR model and its matching processor on the CPU.

    Args:
        model_name: Checkpoint identifier passed to both ``from_pretrained``
            calls, so the model and processor always come from the same
            checkpoint. Defaults to the checkpoint this app was built around.

    Returns:
        A ``(model, processor)`` tuple.
    """
    model = Qwen2VLForConditionalGeneration.from_pretrained(
        model_name,
        torch_dtype=torch.float32,  # weights kept in float32 for CPU execution
        device_map="cpu",
    )
    processor = AutoProcessor.from_pretrained(model_name)
    return model, processor

# Function to extract medicine names
def extract_medicine_names(image):
    """Extract medicine names from a prescription image.

    Kept as a public entry point for existing callers. It delegates to
    ``extract_medicine_names_optimized``, which reuses the process-wide cached
    model/processor pair from ``get_model_and_processor``. The checkpoint is
    therefore not reloaded on every call, and the inference pipeline exists
    in only one place.

    Args:
        image: The prescription image to analyse, or None.

    Returns:
        The model's decoded answer (requested as a numbered list of medicine
        names), or "Please upload an image." when ``image`` is None.
    """
    return extract_medicine_names_optimized(image)

# Module-level cache: populated on first use so the model and processor are
# not reloaded for each request.
model_instance = None
processor_instance = None


def get_model_and_processor():
    """Return the shared ``(model, processor)`` pair, loading it on demand.

    If either cached global is still None, ``load_model()`` is called and
    both globals are overwritten with its result. Otherwise the cached pair
    is returned as-is.
    """
    global model_instance, processor_instance
    already_loaded = model_instance is not None and processor_instance is not None
    if not already_loaded:
        model_instance, processor_instance = load_model()
    return model_instance, processor_instance

# Extraction entry point used by the UI; relies on the cached model/processor.
def extract_medicine_names_optimized(image):
    """Ask the cached Qwen2-VL model to list the medicine names in an image.

    Args:
        image: The prescription image (a PIL image when supplied by a
            ``gr.Image(type="pil")`` component), or None if nothing was
            uploaded.

    Returns:
        The stripped, decoded model answer, or "Please upload an image."
        when ``image`` is None.
    """
    if image is None:
        return "Please upload an image."

    model, processor = get_model_and_processor()

    instruction = "Extract and list ONLY the names of medicines/drugs from this prescription image. Output the medicine names as a numbered list without any additional information or descriptions."
    image_part = {
        "type": "image",
        "image": image,
    }
    text_part = {"type": "text", "text": instruction}
    conversation = [{"role": "user", "content": [image_part, text_part]}]

    # Render the chat template to a prompt string and gather the vision inputs.
    prompt = processor.apply_chat_template(
        conversation, tokenize=False, add_generation_prompt=True
    )
    vision_images, vision_videos = process_vision_info(conversation)
    model_inputs = processor(
        text=[prompt],
        images=vision_images,
        videos=vision_videos,
        padding=True,
        return_tensors="pt",
    )

    sequences = model.generate(**model_inputs, max_new_tokens=256)

    # Each generated sequence starts with its prompt tokens; keep only the new ones.
    completions = []
    for prompt_ids, full_ids in zip(model_inputs.input_ids, sequences):
        completions.append(full_ids[len(prompt_ids):])

    decoded = processor.batch_decode(
        completions, skip_special_tokens=True, clean_up_tokenization_spaces=False
    )
    answer = decoded[0]

    # Defensive cleanup in case the end-of-turn marker survives decoding.
    return answer.replace("<|im_end|>", "").strip()

# Create Gradio interface: upload + button on the left, results on the right,
# followed by usage notes.
with gr.Blocks(title="Medicine Name Extractor") as app:
    for _header_md in (
        "# Medicine Name Extractor",
        "Upload a medical prescription image to extract the names of medicines.",
    ):
        gr.Markdown(_header_md)

    with gr.Row():
        with gr.Column():
            input_image = gr.Image(label="Upload Prescription Image", type="pil")
            extract_btn = gr.Button("Extract Medicine Names", variant="primary")

        with gr.Column():
            output_text = gr.Textbox(lines=10, label="Extracted Medicine Names")

    # Run the extraction on the uploaded image when the button is pressed.
    extract_btn.click(
        fn=extract_medicine_names_optimized,
        outputs=output_text,
        inputs=input_image,
    )

    for _note_md in (
        "### Notes",
        "- This tool uses the Qwen2-VL-OCR model to extract text from prescription images",
        "- For best results, ensure the prescription image is clear and readable",
        "- Processing may take some time as the model runs on CPU",
    ):
        gr.Markdown(_note_md)

# Launch the web app only when run as a script, not when imported.
if __name__ == "__main__":
    app.launch()