Spaces:
Running
Running
import os, json, base64 | |
from io import BytesIO | |
from PIL import Image | |
import gradio as gr | |
import torch | |
from huggingface_hub import hf_hub_download | |
from transformers import ( | |
AutoProcessor, | |
LayoutLMv3Model, | |
T5ForConditionalGeneration, | |
AutoTokenizer | |
) | |
# ── 1) CONFIG & CHECKPOINT ──────────────────────────────────────────────────
HF_REPO = "Uddipan107/ocr-layoutlmv3-base-t5-small"
CKPT_NAME = "pytorch_model.bin"

# 1a) Download the checkpoint from the Hub. The code below reads three
#     entries from it: "layout_model", "t5_model" and "projection".
ckpt_path = hf_hub_download(repo_id=HF_REPO, filename=CKPT_NAME)
# NOTE(review): torch.load unpickles the downloaded file. On PyTorch versions
# where `weights_only` is not the default, a malicious checkpoint could run
# arbitrary code. Consider passing weights_only=True once it is confirmed
# that the checkpoint holds only tensors and plain containers.
ckpt = torch.load(ckpt_path, map_location="cpu")

# ── 2) BUILD MODELS ─────────────────────────────────────────────────────────
# 2a) Processor for LayoutLMv3. OCR is disabled because words and boxes are
#     supplied explicitly at inference time.
processor = AutoProcessor.from_pretrained(
    "microsoft/layoutlmv3-base", apply_ocr=False
)

# 2b) LayoutLMv3 encoder: start from the base weights, then overlay the
#     fine-tuned ones. strict=False means non-matching keys are skipped
#     rather than raising.
layout_model = LayoutLMv3Model.from_pretrained("microsoft/layoutlmv3-base")
layout_model.load_state_dict(ckpt["layout_model"], strict=False)
layout_model.eval().to("cpu")

# 2c) T5 decoder + tokenizer (same overlay approach as the encoder).
t5_model = T5ForConditionalGeneration.from_pretrained("t5-small")
t5_model.load_state_dict(ckpt["t5_model"], strict=False)
t5_model.eval().to("cpu")
tokenizer = AutoTokenizer.from_pretrained("t5-small")

# 2d) Projection head mapping encoder features to the T5 embedding width.
#     The input width is read from the encoder config instead of being
#     hard-coded, so it always matches the encoder's hidden states.
proj_state = ckpt["projection"]
projection = torch.nn.Sequential(
    torch.nn.Linear(layout_model.config.hidden_size, t5_model.config.d_model),
    torch.nn.LayerNorm(t5_model.config.d_model),
    torch.nn.GELU(),
)
projection.load_state_dict(proj_state)
projection.eval().to("cpu")

# 2e) Ensure we have a valid start token for generation.
if t5_model.config.decoder_start_token_id is None:
    # Compare against None explicitly: token id 0 is a valid id but falsy,
    # so `bos_token_id or pad_token_id` would wrongly skip a BOS id of 0.
    t5_model.config.decoder_start_token_id = (
        tokenizer.bos_token_id
        if tokenizer.bos_token_id is not None
        else tokenizer.pad_token_id
    )
if t5_model.config.bos_token_id is None:
    t5_model.config.bos_token_id = t5_model.config.decoder_start_token_id
# ── 3) INFERENCE ────────────────────────────────────────────────────────────
def _find_entry(ndjson_path, img_name):
    """Return the first NDJSON record whose "img_name" equals *img_name*.

    Blank lines are skipped. Returns None when no record matches.
    """
    with open(ndjson_path, "r", encoding="utf-8") as f:
        for raw in f:
            raw = raw.strip()
            if not raw:
                continue
            record = json.loads(raw)
            if record.get("img_name") == img_name:
                return record
    return None


def infer(image_path, json_file):
    """Run the encoder -> projection -> T5 pipeline for one uploaded image.

    Args:
        image_path: Filesystem path of the uploaded image. Its basename is
            used to look up the matching record in the NDJSON file.
        json_file: The uploaded NDJSON file, either as a path string or as
            an object exposing the path through a ``name`` attribute.
            Each line is a JSON object; the matching one must provide
            "src_word_list" and "src_wordbox_list".

    Returns:
        The decoded text generated by the T5 model, or a human-readable
        message when an upload is missing or no record matches the image.
    """
    # Gradio passes None for inputs the user left empty.
    if not image_path or json_file is None:
        return "Please upload both an image and its NDJSON file."

    img_name = os.path.basename(image_path)

    # 3a) Read the uploaded NDJSON & find the matching record.
    #     Depending on the Gradio version/config the upload arrives either
    #     as a plain path or as a file wrapper with a `.name` path.
    json_path = getattr(json_file, "name", json_file)
    entry = _find_entry(json_path, img_name)
    if entry is None:
        # NOTE(review): the leading character looks mis-encoded (possibly an
        # emoji in the original file) — confirm before changing it.
        return f"β No JSON entry for: {img_name}"
    words = entry["src_word_list"]
    boxes = entry["src_wordbox_list"]

    # 3b) Preprocess: image + OCR tokens + boxes.
    #     The context manager closes the underlying file once the pixel data
    #     has been converted into an independent RGB image.
    with Image.open(image_path) as src:
        img = src.convert("RGB")
    enc = processor(
        [img], [words], boxes=[boxes],
        return_tensors="pt", padding=True, truncation=True,
    )
    pixel_values = enc.pixel_values.to("cpu")
    input_ids = enc.input_ids.to("cpu")
    attention_mask = enc.attention_mask.to("cpu")
    bbox = enc.bbox.to("cpu")

    # 3c) Forward pass (inference only, so no gradients are tracked).
    with torch.no_grad():
        out = layout_model(
            pixel_values=pixel_values,
            input_ids=input_ids,
            attention_mask=attention_mask,
            bbox=bbox,
        )
        # Keep only the first `seq_len` positions of the encoder output so
        # the features line up with the text attention mask (presumably the
        # remaining positions are visual tokens — TODO confirm).
        seq_len = input_ids.size(1)
        text_feats = out.last_hidden_state[:, :seq_len, :]
        proj_feats = projection(text_feats)
        gen_ids = t5_model.generate(
            inputs_embeds=proj_feats,
            attention_mask=attention_mask,
            max_length=512,
            decoder_start_token_id=t5_model.config.decoder_start_token_id,
        )

    # 3d) Decode & return.
    return tokenizer.decode(gen_ids[0], skip_special_tokens=True)
# ── 4) GRADIO APP ───────────────────────────────────────────────────────────
# Input widgets: the image is handed to `infer` as a file path, and the
# NDJSON upload carries the per-image word/box records.
_image_input = gr.Image(type="filepath", label="Upload Image")
_ndjson_input = gr.File(label="Upload JSON (NDJSON)")

# Single-function UI: two uploads in, plain text out.
demo = gr.Interface(
    fn=infer,
    inputs=[_image_input, _ndjson_input],
    outputs="text",
    title="OCR Reorder Pipeline",
)

# Only start the server when executed as a script, not on import.
if __name__ == "__main__":
    demo.launch(share=True)