"""Gradio demo that rewrites OCR word lists with a LayoutLMv3 encoder and a T5 decoder.

At import time the script downloads a checkpoint from the Hugging Face Hub,
builds the encoder, the projection head and the decoder on CPU, and defines
the Gradio interface. Running the file as a script launches the app.
"""

import base64
import json
import os
from io import BytesIO

import gradio as gr
import torch
from huggingface_hub import hf_hub_download
from PIL import Image
from transformers import (
    AutoProcessor,
    AutoTokenizer,
    LayoutLMv3Model,
    T5ForConditionalGeneration,
)

# ── 1) CONFIG & CHECKPOINT ────────────────────────────────────────────────
HF_REPO = "Uddipan107/ocr-layoutlmv3-base-t5-small"
CKPT_NAME = "pytorch_model.bin"

# 1a) Download the checkpoint dict from the Hub.
ckpt_path = hf_hub_download(repo_id=HF_REPO, filename=CKPT_NAME)
# NOTE(security): torch.load unpickles the downloaded file, which can execute
# arbitrary code on PyTorch versions where weights_only defaults to False.
# Only point HF_REPO at a repository you trust, or pass weights_only=True once
# the checkpoint is confirmed to contain nothing but tensors.
ckpt = torch.load(ckpt_path, map_location="cpu")

# ── 2) BUILD MODELS ───────────────────────────────────────────────────────
# 2a) Processor for LayoutLMv3. apply_ocr=False because words and boxes are
#     supplied by the caller rather than produced by the processor.
processor = AutoProcessor.from_pretrained(
    "microsoft/layoutlmv3-base",
    apply_ocr=False,
)

# 2b) LayoutLMv3 encoder. strict=False tolerates key mismatches between the
#     checkpoint and the base architecture.
layout_model = LayoutLMv3Model.from_pretrained("microsoft/layoutlmv3-base")
layout_model.load_state_dict(ckpt["layout_model"], strict=False)
layout_model.eval().to("cpu")

# 2c) T5 decoder + tokenizer.
t5_model = T5ForConditionalGeneration.from_pretrained("t5-small")
t5_model.load_state_dict(ckpt["t5_model"], strict=False)
t5_model.eval().to("cpu")
tokenizer = AutoTokenizer.from_pretrained("t5-small")

# 2d) Projection head mapping encoder hidden states to the T5 embedding width.
#     The input width is read from the encoder config (768 for layoutlmv3-base).
proj_state = ckpt["projection"]
projection = torch.nn.Sequential(
    torch.nn.Linear(layout_model.config.hidden_size, t5_model.config.d_model),
    torch.nn.LayerNorm(t5_model.config.d_model),
    torch.nn.GELU(),
)
projection.load_state_dict(proj_state)
projection.eval().to("cpu")

# 2e) Ensure we have a valid start token for generation.
if t5_model.config.decoder_start_token_id is None:
    # Explicit None check: 0 is a legitimate token id and must not be treated
    # as "missing", which a plain `or` would do.
    t5_model.config.decoder_start_token_id = (
        tokenizer.bos_token_id
        if tokenizer.bos_token_id is not None
        else tokenizer.pad_token_id
    )
if t5_model.config.bos_token_id is None:
    t5_model.config.bos_token_id = t5_model.config.decoder_start_token_id


# ── 3) INFERENCE ──────────────────────────────────────────────────────────
def _upload_path(upload):
    """Return the filesystem path of an uploaded file.

    Accepts either a plain path (``str`` / ``os.PathLike``) or an object that
    exposes the path through a ``name`` attribute, so the handler works
    whichever of the two the UI framework passes in.
    """
    if isinstance(upload, (str, os.PathLike)):
        return os.fspath(upload)
    return upload.name


def _find_entry(ndjson_path, img_name):
    """Return the first NDJSON object whose ``img_name`` equals *img_name*.

    Blank lines and lines that do not decode to a JSON object are skipped.
    Returns ``None`` when no record matches. ``json.JSONDecodeError``
    propagates for malformed lines.
    """
    with open(ndjson_path, "r", encoding="utf-8") as handle:
        for raw_line in handle:
            text = raw_line.strip()
            if not text:
                continue
            record = json.loads(text)
            if isinstance(record, dict) and record.get("img_name") == img_name:
                return record
    return None


def infer(image_path, json_file):
    """Generate text for one uploaded image using its record in an NDJSON file.

    Args:
        image_path: Filesystem path of the uploaded image. Its basename is
            used to look up the matching record.
        json_file: The uploaded NDJSON file, either as a path or as an object
            with a ``name`` attribute holding the path. Each line is a JSON
            object with ``img_name``, ``src_word_list`` and
            ``src_wordbox_list`` keys.

    Returns:
        The decoded generation as a string, or a message starting with "❌"
        when an input is missing or no record matches the image name.

    Raises:
        KeyError: If the matching record lacks ``src_word_list`` or
            ``src_wordbox_list``.
    """
    if image_path is None or json_file is None:
        return "❌ Please upload both an image and an NDJSON file."

    img_name = os.path.basename(image_path)

    # 3a) Read the uploaded NDJSON & find the matching record.
    entry = _find_entry(_upload_path(json_file), img_name)
    if entry is None:
        return f"❌ No JSON entry for: {img_name}"

    words = entry["src_word_list"]
    boxes = entry["src_wordbox_list"]

    # 3b) Preprocess: image + OCR tokens + boxes. The context manager closes
    #     the file handle; convert() returns an independent in-memory copy.
    with Image.open(image_path) as raw_image:
        img = raw_image.convert("RGB")
    enc = processor(
        [img],
        [words],
        boxes=[boxes],
        return_tensors="pt",
        padding=True,
        truncation=True,
    )
    pixel_values = enc.pixel_values.to("cpu")
    input_ids = enc.input_ids.to("cpu")
    attention_mask = enc.attention_mask.to("cpu")
    bbox = enc.bbox.to("cpu")

    # 3c) Forward pass.
    with torch.no_grad():
        out = layout_model(
            pixel_values=pixel_values,
            input_ids=input_ids,
            attention_mask=attention_mask,
            bbox=bbox,
        )
        # Keep only the first seq_len positions so the features line up with
        # attention_mask. NOTE(review): presumably the remaining positions are
        # image-patch embeddings appended by the encoder — confirm.
        seq_len = input_ids.size(1)
        text_feats = out.last_hidden_state[:, :seq_len, :]
        proj_feats = projection(text_feats)
        gen_ids = t5_model.generate(
            inputs_embeds=proj_feats,
            attention_mask=attention_mask,
            max_length=512,
            decoder_start_token_id=t5_model.config.decoder_start_token_id,
        )

    # 3d) Decode & return.
    return tokenizer.decode(gen_ids[0], skip_special_tokens=True)


# ── 4) GRADIO APP ─────────────────────────────────────────────────────────
demo = gr.Interface(
    fn=infer,
    inputs=[
        gr.Image(type="filepath", label="Upload Image"),
        gr.File(label="Upload JSON (NDJSON)"),
    ],
    outputs="text",
    title="OCR Reorder Pipeline",
)

if __name__ == "__main__":
    demo.launch(share=True)