Uddipan Basu Bir committed on
Commit
0cfc73f
·
1 Parent(s): ab9088f

Download checkpoint from HF hub in OcrReorderPipeline

Browse files
Files changed (1) hide show
  1. inference.py +15 -2
inference.py CHANGED
@@ -3,12 +3,22 @@ from transformers import Pipeline
3
  from PIL import Image
4
  import base64
5
  from io import BytesIO
 
 
 
 
6
 
7
  class OcrReorderPipeline(Pipeline):
8
  def __init__(self, model, tokenizer, processor, device=0):
9
  super().__init__(model=model, tokenizer=tokenizer,
10
  feature_extractor=processor, device=device)
11
- proj_state = torch.load("pytorch_model.bin", map_location="cpu")["projection"]
 
 
 
 
 
 
12
  self.projection = torch.nn.Sequential(
13
  torch.nn.Linear(768, model.config.d_model),
14
  torch.nn.LayerNorm(model.config.d_model),
@@ -31,17 +41,20 @@ class OcrReorderPipeline(Pipeline):
31
  def _forward(self, model_inputs):
32
  pv, ids, mask, bbox = (
33
  model_inputs[k].to(self.device)
34
- for k in ("pixel_values","input_ids","attention_mask","bbox")
35
  )
 
36
  vision_out = self.model.vision_model(
37
  pixel_values=pv,
38
  input_ids=ids,
39
  attention_mask=mask,
40
  bbox=bbox
41
  )
 
42
  seq_len = ids.size(1)
43
  text_feats = vision_out.last_hidden_state[:, :seq_len, :]
44
  proj_feats = self.projection(text_feats)
 
45
  gen_ids = self.model.text_model.generate(
46
  inputs_embeds=proj_feats,
47
  attention_mask=mask,
 
3
  from PIL import Image
4
  import base64
5
  from io import BytesIO
6
+ from huggingface_hub import hf_hub_download
7
+
8
+ # point at your HF model repo
9
+ HF_MODEL_REPO = "Uddipan107/ocr-layoutlmv3-base-t5-small"
10
 
11
  class OcrReorderPipeline(Pipeline):
12
  def __init__(self, model, tokenizer, processor, device=0):
13
  super().__init__(model=model, tokenizer=tokenizer,
14
  feature_extractor=processor, device=device)
15
+
16
+ # ── Download your fine-tuned checkpoint ───────────────────────────
17
+ ckpt_path = hf_hub_download(repo_id=HF_MODEL_REPO, filename="pytorch_model.bin")
18
+ ckpt = torch.load(ckpt_path, map_location="cpu")
19
+ proj_state= ckpt["projection"]
20
+
21
+ # ── Rebuild & load your projection head ────────────────────────────
22
  self.projection = torch.nn.Sequential(
23
  torch.nn.Linear(768, model.config.d_model),
24
  torch.nn.LayerNorm(model.config.d_model),
 
41
  def _forward(self, model_inputs):
42
  pv, ids, mask, bbox = (
43
  model_inputs[k].to(self.device)
44
+ for k in ("pixel_values", "input_ids", "attention_mask", "bbox")
45
  )
46
+
47
  vision_out = self.model.vision_model(
48
  pixel_values=pv,
49
  input_ids=ids,
50
  attention_mask=mask,
51
  bbox=bbox
52
  )
53
+
54
  seq_len = ids.size(1)
55
  text_feats = vision_out.last_hidden_state[:, :seq_len, :]
56
  proj_feats = self.projection(text_feats)
57
+
58
  gen_ids = self.model.text_model.generate(
59
  inputs_embeds=proj_feats,
60
  attention_mask=mask,