Uddipan Basu Bir committed
Commit f21911e · Parent: a01cae7

Download checkpoint from HF hub in OcrReorderPipeline

Files changed (1): app.py (+60, −60)
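
Note on the change below: the new loading code assumes pytorch_model.bin is not a single flat state_dict but a plain Python dict holding one state_dict per component, under the keys "layout_model", "t5_model", and "projection". A minimal sketch of how a compatible checkpoint might be produced at training time (illustrative only, not part of this commit; freshly initialised modules stand in for the trained ones):

import torch
from transformers import LayoutLMv3Model, T5ForConditionalGeneration

layout_model = LayoutLMv3Model.from_pretrained("microsoft/layoutlmv3-base")
t5_model = T5ForConditionalGeneration.from_pretrained("t5-small")
projection = torch.nn.Sequential(
    torch.nn.Linear(768, t5_model.config.d_model),  # 768 = LayoutLMv3 hidden size
    torch.nn.LayerNorm(t5_model.config.d_model),
    torch.nn.GELU(),
)

# One state_dict per component, keyed exactly as the app reads them back.
torch.save(
    {
        "layout_model": layout_model.state_dict(),
        "t5_model": t5_model.state_dict(),
        "projection": projection.state_dict(),
    },
    "pytorch_model.bin",
)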
app.py CHANGED
@@ -1,6 +1,4 @@
-import os
-import json
-import base64
+import os, json, base64
 from io import BytesIO
 from PIL import Image
 import gradio as gr
@@ -13,100 +11,102 @@ from transformers import (
     AutoTokenizer
 )
 
-# ── 1) MODEL SETUP ─────────────────────────────────────────────────────
-repo = "Uddipan107/ocr-layoutlmv3-base-t5-small"
-
-# Processor
-processor = AutoProcessor.from_pretrained(
-    repo,
-    subfolder="preprocessor",
-    apply_ocr=False
-)
-
-# Encoder & Decoder
-layout_model = LayoutLMv3Model.from_pretrained(repo).to("cpu").eval()
-t5_model = T5ForConditionalGeneration.from_pretrained(repo).to("cpu").eval()
-tokenizer = AutoTokenizer.from_pretrained(
-    repo, subfolder="preprocessor"
+# ── 1) CONFIG & CHECKPOINT ────────────────────────────────────────────────
+HF_REPO = "Uddipan107/ocr-layoutlmv3-base-t5-small"
+CKPT_NAME = "pytorch_model.bin"
+
+# 1a) Download the checkpoint dict from your Hub
+ckpt_path = hf_hub_download(repo_id=HF_REPO, filename=CKPT_NAME)
+ckpt = torch.load(ckpt_path, map_location="cpu")
+
+# ── 2) BUILD MODELS ───────────────────────────────────────────────────────
+# 2a) Processor for LayoutLMv3
+processor = AutoProcessor.from_pretrained(
+    "microsoft/layoutlmv3-base", apply_ocr=False
 )
 
-# Ensure decoder_start_token_id and bos_token_id are set
-if t5_model.config.decoder_start_token_id is None:
-    fallback = tokenizer.bos_token_id or tokenizer.eos_token_id
-    t5_model.config.decoder_start_token_id = fallback
-if t5_model.config.bos_token_id is None:
-    t5_model.config.bos_token_id = t5_model.config.decoder_start_token_id
-
-# Projection head
-ckpt_file = hf_hub_download(repo_id=repo, filename="pytorch_model.bin")
-ckpt = torch.load(ckpt_file, map_location="cpu")
+# 2b) LayoutLMv3 encoder
+layout_model = LayoutLMv3Model.from_pretrained("microsoft/layoutlmv3-base")
+layout_model.load_state_dict(ckpt["layout_model"], strict=False)
+layout_model.eval().to("cpu")
+
+# 2c) T5 decoder + tokenizer
+t5_model = T5ForConditionalGeneration.from_pretrained("t5-small")
+t5_model.load_state_dict(ckpt["t5_model"], strict=False)
+t5_model.eval().to("cpu")
+
+tokenizer = AutoTokenizer.from_pretrained("t5-small")
+
+# 2d) Projection head
 proj_state = ckpt["projection"]
 projection = torch.nn.Sequential(
     torch.nn.Linear(768, t5_model.config.d_model),
     torch.nn.LayerNorm(t5_model.config.d_model),
     torch.nn.GELU()
-).to("cpu")
+)
 projection.load_state_dict(proj_state)
+projection.eval().to("cpu")
 
-# ── 2) INFERENCE FUNCTION ─────────────────────────────────────────────
+# 2e) Ensure we have a valid start token for generation
+if t5_model.config.decoder_start_token_id is None:
+    t5_model.config.decoder_start_token_id = tokenizer.bos_token_id or tokenizer.pad_token_id
+if t5_model.config.bos_token_id is None:
+    t5_model.config.bos_token_id = t5_model.config.decoder_start_token_id
+
+# ── 3) INFERENCE ─────────────────────────────────────────────────────────
 def infer(image_path, json_file):
     img_name = os.path.basename(image_path)
 
-    # Load NDJSON
-    data = []
+    # 3a) Read the uploaded NDJSON & find the matching record
+    entry = None
     with open(json_file.name, "r", encoding="utf-8") as f:
         for line in f:
-            if not line.strip():
+            line = line.strip()
+            if not line:
                 continue
-            data.append(json.loads(line))
+            obj = json.loads(line)
+            if obj.get("img_name") == img_name:
+                entry = obj
+                break
 
-    entry = next((e for e in data if e.get("img_name") == img_name), None)
     if entry is None:
-        return f"❌ No JSON entry found for image '{img_name}'"
+        return f"❌ No JSON entry for: {img_name}"
 
-    words = entry.get("src_word_list", [])
-    boxes = entry.get("src_wordbox_list", [])
+    words = entry["src_word_list"]
+    boxes = entry["src_wordbox_list"]
 
-    # Preprocess image + tokens
+    # 3b) Preprocess: image + OCR tokens + boxes
    img = Image.open(image_path).convert("RGB")
-    encoding = processor(
-        [img], [words], boxes=[boxes],
-        return_tensors="pt", padding=True, truncation=True
-    )
-    pixel_values = encoding.pixel_values.to("cpu")
-    input_ids = encoding.input_ids.to("cpu")
-    attention_mask = encoding.attention_mask.to("cpu")
-    bbox = encoding.bbox.to("cpu")
-
-    # Forward pass
+    enc = processor([img], [words], boxes=[boxes],
+                    return_tensors="pt", padding=True, truncation=True)
+    pixel_values = enc.pixel_values.to("cpu")
+    input_ids = enc.input_ids.to("cpu")
+    attention_mask = enc.attention_mask.to("cpu")
+    bbox = enc.bbox.to("cpu")
+
+    # 3c) Forward pass
     with torch.no_grad():
-        # LayoutLMv3 encoding
-        lm_out = layout_model(
+        out = layout_model(
             pixel_values=pixel_values,
             input_ids=input_ids,
             attention_mask=attention_mask,
             bbox=bbox
         )
         seq_len = input_ids.size(1)
-        text_feats = lm_out.last_hidden_state[:, :seq_len, :]
-
-        # Projection + T5 decoding
+        text_feats = out.last_hidden_state[:, :seq_len, :]
         proj_feats = projection(text_feats)
-        gen_ids = t5_model.generate(
+
+        gen_ids = t5_model.generate(
             inputs_embeds=proj_feats,
             attention_mask=attention_mask,
             max_length=512,
-            decoder_start_token_id=t5_model.config.decoder_start_token_id,
-            bos_token_id=t5_model.config.bos_token_id
+            decoder_start_token_id=t5_model.config.decoder_start_token_id
         )
 
-    # Decode and return
-    result = tokenizer.batch_decode(
-        gen_ids, skip_special_tokens=True
-    )[0]
-    return result
+    # 3d) Decode & return
+    return tokenizer.decode(gen_ids[0], skip_special_tokens=True)
 
-# ── 3) GRADIO INTERFACE ────────────────────────────────────────────────
+# ── 4) GRADIO APP ────────────────────────────────────────────────────────
 demo = gr.Interface(
     fn=infer,
     inputs=[
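
One caveat with the new loading scheme: load_state_dict(..., strict=False) silently ignores missing and unexpected keys, so a corrupted or mis-keyed upload would still "load". A hypothetical smoke test (not part of this commit) that fails fast if the downloaded checkpoint lacks the three sub-dicts the app indexes:

import torch
from huggingface_hub import hf_hub_download

path = hf_hub_download(
    repo_id="Uddipan107/ocr-layoutlmv3-base-t5-small",
    filename="pytorch_model.bin",
)
ckpt = torch.load(path, map_location="cpu")

# The app reads exactly these three keys; abort early if any is absent.
missing = {"layout_model", "t5_model", "projection"} - set(ckpt)
assert not missing, f"checkpoint is missing: {missing}"
print({k: len(v) for k, v in ckpt.items()})  # tensor count per component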