# Uddipan Basu Bir
# Download checkpoint from HF hub in OcrReorderPipeline
# f21911e
import os, json, base64
from io import BytesIO
from PIL import Image
import gradio as gr
import torch
from huggingface_hub import hf_hub_download
from transformers import (
AutoProcessor,
LayoutLMv3Model,
T5ForConditionalGeneration,
AutoTokenizer
)
# ── 1) CONFIG & CHECKPOINT ────────────────────────────────────────────────
# Hub repository and file name of the fine-tuned checkpoint.
HF_REPO = "Uddipan107/ocr-layoutlmv3-base-t5-small"
CKPT_NAME = "pytorch_model.bin"

# 1a) Download the checkpoint dict from your Hub
ckpt_path = hf_hub_download(repo_id=HF_REPO, filename=CKPT_NAME)

# SECURITY NOTE(review): torch.load unpickles the downloaded file, which can
# execute arbitrary code if the Hub repo is ever compromised. If the installed
# torch version supports it, consider passing weights_only=True here.
ckpt = torch.load(ckpt_path, map_location="cpu")

# 1b) Fail fast with a descriptive error if the checkpoint does not contain
# the three sections the rest of this script indexes into.
_REQUIRED_CKPT_KEYS = ("layout_model", "t5_model", "projection")
_missing_ckpt_keys = [key for key in _REQUIRED_CKPT_KEYS if key not in ckpt]
if _missing_ckpt_keys:
    raise KeyError(
        f"Checkpoint {HF_REPO}/{CKPT_NAME} is missing required sections: "
        f"{_missing_ckpt_keys}"
    )
# ── 2) BUILD MODELS ───────────────────────────────────────────────────────
def _load_weights(module, state_dict, label):
    """Load *state_dict* into *module* non-strictly and report mismatches.

    ``strict=False`` tolerates missing/unexpected keys, but doing so silently
    can hide a checkpoint whose keys do not line up at all (the module would
    then keep its base pretrained weights). Any mismatch is therefore surfaced
    as a warning instead of being dropped.

    Args:
        module: The ``torch.nn.Module`` receiving the weights.
        state_dict: Mapping of parameter names to tensors.
        label: Short name used in the warning message.

    Returns:
        The same *module*, for convenience.
    """
    import warnings  # local import keeps this block self-contained

    result = module.load_state_dict(state_dict, strict=False)
    if result.missing_keys or result.unexpected_keys:
        warnings.warn(
            f"{label}: checkpoint did not match the model exactly "
            f"(missing keys: {len(result.missing_keys)}, "
            f"unexpected keys: {len(result.unexpected_keys)})"
        )
    return module


# 2a) Processor for LayoutLMv3 (apply_ocr=False: words and boxes are supplied
# by the caller rather than produced by the processor)
processor = AutoProcessor.from_pretrained(
    "microsoft/layoutlmv3-base", apply_ocr=False
)

# 2b) LayoutLMv3 encoder
layout_model = LayoutLMv3Model.from_pretrained("microsoft/layoutlmv3-base")
_load_weights(layout_model, ckpt["layout_model"], "layout_model")
layout_model.eval().to("cpu")

# 2c) T5 decoder + tokenizer
t5_model = T5ForConditionalGeneration.from_pretrained("t5-small")
_load_weights(t5_model, ckpt["t5_model"], "t5_model")
t5_model.eval().to("cpu")
tokenizer = AutoTokenizer.from_pretrained("t5-small")

# 2d) Projection head: maps encoder hidden states to the T5 embedding width.
# The input width is read from the encoder config (768 for layoutlmv3-base)
# instead of being hard-coded.
proj_state = ckpt["projection"]
projection = torch.nn.Sequential(
    torch.nn.Linear(layout_model.config.hidden_size, t5_model.config.d_model),
    torch.nn.LayerNorm(t5_model.config.d_model),
    torch.nn.GELU(),
)
projection.load_state_dict(proj_state)
projection.eval().to("cpu")

# 2e) Ensure we have a valid start token for generation.
# Explicit ``is None`` checks are used because token id 0 is a valid id and
# must not be treated as "missing" the way a bare ``or`` would.
if t5_model.config.decoder_start_token_id is None:
    _start_token_id = tokenizer.bos_token_id
    if _start_token_id is None:
        _start_token_id = tokenizer.pad_token_id
    t5_model.config.decoder_start_token_id = _start_token_id
if t5_model.config.bos_token_id is None:
    t5_model.config.bos_token_id = t5_model.config.decoder_start_token_id
# ── 3) INFERENCE ─────────────────────────────────────────────────────────
def _find_ndjson_entry(ndjson_path, img_name):
    """Return the first NDJSON record whose "img_name" equals *img_name*.

    Blank lines and records that are not JSON objects are skipped. Returns
    ``None`` when no record matches. A line that is not valid JSON still
    raises ``json.JSONDecodeError``.
    """
    with open(ndjson_path, "r", encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            record = json.loads(line)
            if isinstance(record, dict) and record.get("img_name") == img_name:
                return record
    return None


def infer(image_path, json_file):
    """Run the OCR reorder pipeline for one uploaded image.

    The NDJSON upload is searched for a record whose "img_name" equals the
    base name of *image_path*; that record's "src_word_list" and
    "src_wordbox_list" are fed, together with the image, through the
    LayoutLMv3 encoder, the projection head and T5 generation.

    Args:
        image_path: Filesystem path of the uploaded image.
        json_file: Either a path (``str`` / ``os.PathLike``) to the NDJSON
            file or an upload object exposing the path as ``.name``.

    Returns:
        The decoded generated text, or a message starting with "❌" when an
        input is missing or no usable record is found.
    """
    # Guard against the form being submitted with an empty input.
    if not image_path or json_file is None:
        return "❌ Please upload both an image and an NDJSON file."

    img_name = os.path.basename(image_path)

    # Accept a plain path as well as a file-like upload wrapper with `.name`.
    if isinstance(json_file, (str, os.PathLike)):
        ndjson_path = json_file
    else:
        ndjson_path = json_file.name

    # 3a) Read the uploaded NDJSON & find the matching record
    entry = _find_ndjson_entry(ndjson_path, img_name)
    if entry is None:
        return f"❌ No JSON entry for: {img_name}"

    words = entry.get("src_word_list")
    boxes = entry.get("src_wordbox_list")
    if words is None or boxes is None:
        return (
            f"❌ JSON entry for {img_name} is missing "
            "'src_word_list' or 'src_wordbox_list'"
        )

    # 3b) Preprocess: image + OCR tokens + boxes
    # NOTE(review): boxes are passed through unchanged; LayoutLMv3 presumably
    # expects them normalized to the 0-1000 range — confirm against the data.
    with Image.open(image_path) as raw_img:
        # convert() returns an independent image, so the file can be closed.
        img = raw_img.convert("RGB")
    enc = processor([img], [words], boxes=[boxes],
                    return_tensors="pt", padding=True, truncation=True)
    pixel_values = enc.pixel_values.to("cpu")
    input_ids = enc.input_ids.to("cpu")
    attention_mask = enc.attention_mask.to("cpu")
    bbox = enc.bbox.to("cpu")

    # 3c) Forward pass
    with torch.no_grad():
        out = layout_model(
            pixel_values=pixel_values,
            input_ids=input_ids,
            attention_mask=attention_mask,
            bbox=bbox
        )
        # Keep only the leading positions that line up with the text tokens,
        # so the features match attention_mask in length.
        seq_len = input_ids.size(1)
        text_feats = out.last_hidden_state[:, :seq_len, :]
        proj_feats = projection(text_feats)
        gen_ids = t5_model.generate(
            inputs_embeds=proj_feats,
            attention_mask=attention_mask,
            max_length=512,
            decoder_start_token_id=t5_model.config.decoder_start_token_id
        )

    # 3d) Decode & return
    return tokenizer.decode(gen_ids[0], skip_special_tokens=True)
# ── 4) GRADIO APP ────────────────────────────────────────────────────────
# Input widgets: the image is handed to `infer` as a file path, and the
# NDJSON upload carries the per-image word/box records.
_image_input = gr.Image(type="filepath", label="Upload Image")
_ndjson_input = gr.File(label="Upload JSON (NDJSON)")

# Wire the two inputs to `infer` and show its string result as plain text.
demo = gr.Interface(
    fn=infer,
    inputs=[_image_input, _ndjson_input],
    outputs="text",
    title="OCR Reorder Pipeline",
)

# Launch the app only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch(share=True)