# Hugging Face Space (page status at capture time: Running)
import torch | |
from transformers import Pipeline | |
from PIL import Image | |
import base64 | |
from io import BytesIO | |
from huggingface_hub import hf_hub_download | |
# Hugging Face Hub model repo whose ``pytorch_model.bin`` holds the fine-tuned
# projection-head weights under the "projection" key.
HF_MODEL_REPO = "Uddipan107/ocr-layoutlmv3-base-t5-small"
class OcrReorderPipeline(Pipeline):
    """Pipeline that encodes an image plus OCR words/boxes and generates text.

    The wrapped ``model`` must expose ``vision_model`` (called with
    ``pixel_values``, ``input_ids``, ``attention_mask`` and ``bbox``) and
    ``text_model`` (which provides ``generate``). A small projection head,
    loaded from the ``"projection"`` entry of a fine-tuned checkpoint, maps the
    encoder's text-token features into the embedding space passed to
    ``text_model.generate`` as ``inputs_embeds``.
    """

    def __init__(self, model, tokenizer, processor, device=0, model_repo=HF_MODEL_REPO):
        """Build the pipeline and load the fine-tuned projection head.

        Args:
            model: Composite model exposing ``vision_model`` and ``text_model``.
            tokenizer: Tokenizer used by ``postprocess`` to decode generated ids.
            processor: Callable that turns (images, words, boxes) into model
                tensors; registered as the pipeline's ``feature_extractor``.
            device: Device forwarded to ``Pipeline.__init__`` (default ``0``).
            model_repo: Hub repo containing ``pytorch_model.bin`` with a
                ``"projection"`` state dict. Defaults to ``HF_MODEL_REPO``.

        Raises:
            KeyError: If the checkpoint has no ``"projection"`` entry.
        """
        super().__init__(model=model, tokenizer=tokenizer,
                         feature_extractor=processor, device=device)

        # -- Download the fine-tuned checkpoint --------------------------------
        # NOTE(security): torch.load unpickles the downloaded file, so only load
        # checkpoints from a trusted repo. Recent torch versions default to
        # weights_only=True, which restricts what can be unpickled.
        ckpt_path = hf_hub_download(repo_id=model_repo, filename="pytorch_model.bin")
        ckpt = torch.load(ckpt_path, map_location="cpu")
        try:
            proj_state = ckpt["projection"]
        except KeyError:
            raise KeyError(
                f"Checkpoint {ckpt_path!r} from {model_repo!r} has no 'projection' "
                f"entry; first keys: {list(ckpt)[:10]}"
            ) from None

        # -- Rebuild and load the projection head (Linear -> LayerNorm -> GELU) --
        # Read the layer sizes from the checkpoint instead of hard-coding them.
        # Fall back to the original 768 -> 512 (T5-small hidden size) if the
        # weight is missing, so load_state_dict reports the mismatch as before.
        linear_weight = proj_state.get("0.weight")
        if linear_weight is not None:
            d_model, enc_dim = linear_weight.shape
        else:
            d_model, enc_dim = 512, 768
        self.projection = torch.nn.Sequential(
            torch.nn.Linear(enc_dim, d_model),
            torch.nn.LayerNorm(d_model),
            torch.nn.GELU(),
        )
        self.projection.load_state_dict(proj_state)
        self.projection.to(self.device)
        # Inference only: put the head in eval mode explicitly.
        self.projection.eval()

    def _sanitize_parameters(self, **kwargs):
        """Split call-time kwargs into (preprocess, forward, postprocess) params.

        ``words`` and ``boxes`` go to ``preprocess``. The image itself arrives as
        the positional ``inputs`` argument. ``max_length``, if given, goes to
        ``_forward``.
        """
        preprocess_params = {"words": kwargs.get("words"), "boxes": kwargs.get("boxes")}
        forward_params = {}
        if "max_length" in kwargs:
            forward_params["max_length"] = kwargs["max_length"]
        return preprocess_params, forward_params, {}

    def preprocess(self, image, words, boxes):
        """Decode the input image and encode it together with its words and boxes.

        Args:
            image: Base64-encoded image as ``str``/``bytes``, optionally prefixed
                with a ``data:<mime>;base64,`` header, or an already-loaded
                ``PIL.Image.Image``.
            words: OCR words for the image.
            boxes: Bounding boxes for ``words``, in the format the processor expects.

        Returns:
            The processor output for a batch of one, as PyTorch tensors.

        Raises:
            ValueError: If ``words`` or ``boxes`` was not supplied.
        """
        if words is None or boxes is None:
            raise ValueError(
                "OcrReorderPipeline requires both `words` and `boxes` keyword arguments."
            )
        if isinstance(image, Image.Image):
            img = image
        else:
            if isinstance(image, str) and image.startswith("data:"):
                # Drop a data-URL header; b64decode would otherwise fold its
                # letters into the payload and corrupt the image bytes.
                image = image.partition(",")[2]
            img = Image.open(BytesIO(base64.b64decode(image)))
        return self.feature_extractor(
            [img.convert("RGB")], [words], boxes=[boxes],
            return_tensors="pt", padding=True, truncation=True,
        )

    def _forward(self, model_inputs, max_length=512):
        """Encode image and layout, project text-token features, and generate ids.

        Args:
            model_inputs: Output of ``preprocess``; must contain
                ``pixel_values``, ``input_ids``, ``attention_mask`` and ``bbox``.
            max_length: Maximum generation length passed to ``generate``.

        Returns:
            ``{"generated_ids": ...}`` holding the output of ``text_model.generate``.
        """
        pixel_values, input_ids, attention_mask, bbox = (
            model_inputs[key].to(self.device)
            for key in ("pixel_values", "input_ids", "attention_mask", "bbox")
        )
        encoder_out = self.model.vision_model(
            pixel_values=pixel_values,
            input_ids=input_ids,
            attention_mask=attention_mask,
            bbox=bbox,
        )
        # Keep the first seq_len positions of the encoder output. This code treats
        # them as the text tokens; any later positions (presumably image patches,
        # TODO confirm) are dropped so the features align with attention_mask.
        seq_len = input_ids.size(1)
        text_feats = encoder_out.last_hidden_state[:, :seq_len, :]
        proj_feats = self.projection(text_feats)
        gen_ids = self.model.text_model.generate(
            inputs_embeds=proj_feats,
            attention_mask=attention_mask,
            max_length=max_length,
        )
        return {"generated_ids": gen_ids}

    def postprocess(self, model_outputs):
        """Decode generated ids to strings (one per batch item), skipping special tokens."""
        return self.tokenizer.batch_decode(
            model_outputs["generated_ids"],
            skip_special_tokens=True,
        )