"""Custom Hugging Face pipeline that feeds projected encoder features to a text generator."""

import base64
from io import BytesIO

import torch
from huggingface_hub import hf_hub_download
from PIL import Image
from transformers import Pipeline

# HF model repo containing pytorch_model.bin with 'projection' state
HF_MODEL_REPO = "Uddipan107/ocr-layoutlmv3-base-t5-small"


class OcrReorderPipeline(Pipeline):
    """Pipeline that encodes an image plus OCR words/boxes and generates text.

    The flow is: ``preprocess`` decodes a base64 image and runs the processor,
    ``_forward`` runs ``model.vision_model``, projects the leading hidden
    states through a small projection head and calls
    ``model.text_model.generate``; ``postprocess`` decodes the generated ids.

    Call-time keyword arguments:
        words: OCR words for the image (required).
        boxes: Bounding boxes matching ``words`` (required).
        max_length: Optional generation length limit (defaults to 512).
    """

    def __init__(self, model, tokenizer, processor, device=0):
        """Build the pipeline and load the fine-tuned projection head.

        Args:
            model: Object exposing ``vision_model`` and ``text_model``.
            tokenizer: Tokenizer used to decode generated ids.
            processor: Processor registered as the pipeline's feature extractor.
            device: Device passed through to ``Pipeline.__init__``.

        Raises:
            KeyError: If the downloaded checkpoint has no ``"projection"`` entry.
        """
        super().__init__(
            model=model,
            tokenizer=tokenizer,
            feature_extractor=processor,
            device=device,
        )

        # Download the fine-tuned checkpoint from the Hub.
        ckpt_path = hf_hub_download(
            repo_id=HF_MODEL_REPO, filename="pytorch_model.bin"
        )
        # NOTE(review): torch.load unpickles the file; on torch versions where
        # ``weights_only`` defaults to False this can execute arbitrary code
        # from the downloaded checkpoint. Consider ``weights_only=True`` once
        # the minimum supported torch version is confirmed.
        ckpt = torch.load(ckpt_path, map_location="cpu")
        proj_state = ckpt["projection"]

        # Rebuild and load the projection head (T5-small hidden size = 512).
        # The 768 input width presumably matches the encoder's hidden size;
        # verify against the encoder config if the base model changes.
        encoder_dim = 768
        d_model = 512
        self.projection = torch.nn.Sequential(
            torch.nn.Linear(encoder_dim, d_model),
            torch.nn.LayerNorm(d_model),
            torch.nn.GELU(),
        )
        self.projection.load_state_dict(proj_state)
        self.projection.to(self.device)
        # Inference only: keep the head in eval mode.
        self.projection.eval()

    def _sanitize_parameters(self, **kwargs):
        """Split caller kwargs into preprocess / forward / postprocess dicts.

        Only keys the caller actually supplied are forwarded. Returning
        ``None`` placeholders for absent keys would override any parameters
        stored at construction time when the base class merges the dicts.
        """
        preprocess_params = {
            key: kwargs[key] for key in ("words", "boxes") if key in kwargs
        }
        forward_params = {}
        if "max_length" in kwargs:
            forward_params["max_length"] = kwargs["max_length"]
        return preprocess_params, forward_params, {}

    def preprocess(self, image, words=None, boxes=None):
        """Decode the base64 image and run the processor on image/words/boxes.

        Args:
            image: Base64-encoded image data (the positional pipeline input).
            words: OCR words for the image.
            boxes: Bounding boxes matching ``words``.

        Returns:
            The processor output as PyTorch tensors for a batch of one.

        Raises:
            ValueError: If ``words`` or ``boxes`` is not provided.
        """
        if words is None or boxes is None:
            raise ValueError(
                "OcrReorderPipeline requires both 'words' and 'boxes' keyword "
                "arguments."
            )

        data = base64.b64decode(image)
        img = Image.open(BytesIO(data)).convert("RGB")
        return self.feature_extractor(
            [img],
            [words],
            boxes=[boxes],
            return_tensors="pt",
            padding=True,
            truncation=True,
        )

    def _forward(self, model_inputs, max_length=512):
        """Encode the inputs, project the features and generate token ids.

        Args:
            model_inputs: Mapping with ``pixel_values``, ``input_ids``,
                ``attention_mask`` and ``bbox`` tensors.
            max_length: ``max_length`` passed to ``generate``.

        Returns:
            ``{"generated_ids": <output of text_model.generate>}``.
        """
        pv, ids, mask, bbox = (
            model_inputs[key].to(self.device)
            for key in ("pixel_values", "input_ids", "attention_mask", "bbox")
        )
        vision_out = self.model.vision_model(
            pixel_values=pv,
            input_ids=ids,
            attention_mask=mask,
            bbox=bbox,
        )
        # Keep only the first ``seq_len`` positions so the features line up
        # with ``attention_mask``; assumes the encoder emits text-token states
        # before any visual states - TODO confirm.
        seq_len = ids.size(1)
        text_feats = vision_out.last_hidden_state[:, :seq_len, :]
        proj_feats = self.projection(text_feats)
        gen_ids = self.model.text_model.generate(
            inputs_embeds=proj_feats,
            attention_mask=mask,
            max_length=max_length,
        )
        return {"generated_ids": gen_ids}

    def postprocess(self, model_outputs):
        """Decode generated ids into strings, dropping special tokens."""
        return self.tokenizer.batch_decode(
            model_outputs["generated_ids"], skip_special_tokens=True
        )