# ocr-reorder-space / inference.py
# Uddipan Basu Bir
# Download checkpoint from HF hub in OcrReorderPipeline
# 4956f20
import torch
from transformers import Pipeline
from PIL import Image
import base64
from io import BytesIO
from huggingface_hub import hf_hub_download
# Hugging Face Hub repo whose pytorch_model.bin checkpoint is downloaded by
# OcrReorderPipeline.__init__; the checkpoint's 'projection' entry holds the
# state dict of the projection head.
HF_MODEL_REPO = "Uddipan107/ocr-layoutlmv3-base-t5-small"
class OcrReorderPipeline(Pipeline):
    """Pipeline that turns OCR words and boxes into generated, decoded text.

    A call takes a base64-encoded image plus ``words`` and ``boxes`` keyword
    arguments. The inputs are encoded by the processor, passed through
    ``model.vision_model``, projected from 768 to 512 features by a small
    head whose weights are downloaded from ``HF_MODEL_REPO``, and finally fed
    as ``inputs_embeds`` to ``model.text_model.generate``. The generated ids
    are decoded with the tokenizer.
    """

    def __init__(self, model, tokenizer, processor, device=0):
        """Build the pipeline and load the fine-tuned projection head.

        Args:
            model: Object exposing ``vision_model`` and ``text_model``
                sub-models (both are used in ``_forward``).
            tokenizer: Tokenizer used to decode the generated ids.
            processor: Callable that encodes images, words and boxes; it is
                stored as the pipeline's ``feature_extractor``.
            device: Device forwarded to ``Pipeline.__init__``.

        Raises:
            KeyError: If the downloaded checkpoint has no ``"projection"``
                entry.
            RuntimeError: If the checkpoint's projection state does not match
                the projection head built here (raised by
                ``load_state_dict``).
        """
        super().__init__(model=model, tokenizer=tokenizer,
                         feature_extractor=processor, device=device)

        # ── Download the fine-tuned checkpoint ────────────────────────────
        ckpt_path = hf_hub_download(repo_id=HF_MODEL_REPO,
                                    filename="pytorch_model.bin")
        # NOTE(security): torch.load unpickles the file, which can execute
        # arbitrary code. Only point HF_MODEL_REPO at a repository you trust;
        # consider `weights_only=True` if the installed torch supports it.
        ckpt = torch.load(ckpt_path, map_location="cpu")
        proj_state = ckpt["projection"]

        # ── Rebuild & load the projection head (T5-small hidden size = 512) ─
        d_model = 512
        self.projection = torch.nn.Sequential(
            torch.nn.Linear(768, d_model),
            torch.nn.LayerNorm(d_model),
            torch.nn.GELU(),
        )
        self.projection.load_state_dict(proj_state)
        self.projection.to(self.device)
        # The head is used for inference only; make that explicit.
        self.projection.eval()

    def _sanitize_parameters(self, **kwargs):
        """Split call-time kwargs into (preprocess, forward, postprocess) dicts.

        Only ``words`` and ``boxes`` are recognised, and only the ones the
        caller actually supplied are returned, so an omitted argument is not
        turned into an explicit ``None``. The image itself arrives
        positionally as the pipeline input.
        """
        preprocess_params = {
            key: kwargs[key] for key in ("words", "boxes") if key in kwargs
        }
        return preprocess_params, {}, {}

    def preprocess(self, image, words=None, boxes=None):
        """Decode the base64 image and encode it together with words/boxes.

        Args:
            image: Base64-encoded image data (the positional pipeline input).
            words: OCR words for the image; required.
            boxes: Bounding boxes matching ``words``; required.

        Returns:
            The processor output for a batch of one, as PyTorch tensors.

        Raises:
            ValueError: If ``words`` or ``boxes`` is missing.
        """
        if words is None or boxes is None:
            # Fail early with a clear message instead of handing [None] to
            # the processor.
            raise ValueError(
                "OcrReorderPipeline requires both 'words' and 'boxes' "
                "keyword arguments alongside the base64-encoded image."
            )

        data = base64.b64decode(image)
        # convert() returns a new image, so the opened one can be closed.
        with Image.open(BytesIO(data)) as raw:
            img = raw.convert("RGB")

        return self.feature_extractor(
            [img], [words], boxes=[boxes],
            return_tensors="pt", padding=True, truncation=True
        )

    def _forward(self, model_inputs):
        """Encode, project and generate output ids.

        Args:
            model_inputs: Mapping with ``pixel_values``, ``input_ids``,
                ``attention_mask`` and ``bbox`` tensors.

        Returns:
            dict with a single key, ``"generated_ids"``.
        """
        pv, ids, mask, bbox = (
            model_inputs[k].to(self.device)
            for k in ("pixel_values", "input_ids", "attention_mask", "bbox")
        )
        vision_out = self.model.vision_model(
            pixel_values=pv,
            input_ids=ids,
            attention_mask=mask,
            bbox=bbox
        )

        # Keep only the first `seq_len` positions so the features line up
        # with `mask`. Presumably these are the text-token positions and any
        # later ones are image-patch tokens — TODO confirm against the
        # vision model.
        seq_len = ids.size(1)
        text_feats = vision_out.last_hidden_state[:, :seq_len, :]
        proj_feats = self.projection(text_feats)

        gen_ids = self.model.text_model.generate(
            inputs_embeds=proj_feats,
            attention_mask=mask,
            max_length=512
        )
        return {"generated_ids": gen_ids}

    def postprocess(self, model_outputs):
        """Decode the generated ids to strings, dropping special tokens."""
        return self.tokenizer.batch_decode(
            model_outputs["generated_ids"],
            skip_special_tokens=True
        )