ai42 committed on
Commit
e0a9079
·
1 Parent(s): 0102fad

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +433 -10
app.py CHANGED
@@ -1,10 +1,433 @@
1
- The cache for model files in Transformers v4.22.0 has been updated. Migrating your old cache. This is a one-time only operation. You can interrupt this and resume the migration later on by calling `transformers.utils.move_cache()`.
2
- Moving 0 files to the new cache system
3
-
4
- 0it [00:00, ?it/s]
5
- 0it [00:00, ?it/s]
6
- image-classification is already registered. Overwriting pipeline for task image-classification...
7
- Traceback (most recent call last):
8
- File "/home/user/app/app.py", line 14, in <module>
9
- from docquery.document import load_document, ImageDocumenta
10
- ImportError: cannot import name 'ImageDocumenta' from 'docquery.document' (/home/user/.local/lib/python3.10/site-packages/docquery/document.py)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ os.environ["TOKENIZERS_PARALLELISM"] = "false"
4
+
5
+ from PIL import Image, ImageDraw
6
+ import traceback
7
+
8
+ import gradio as gr
9
+
10
+ import torch
11
+ from docquery import pipeline
12
+ from docquery.document import load_document, ImageDocument
13
+ from docquery.ocr_reader import get_ocr_reader
14
+
15
+
16
def ensure_list(x):
    """Wrap *x* in a list unless it already is one."""
    return x if isinstance(x, list) else [x]
21
+
22
+
23
# Display name (shown in the model radio) -> Hugging Face checkpoint id.
CHECKPOINTS = {
    "LayoutLMv1 🦉": "impira/layoutlm-document-qa",
    "LayoutLMv1 for Invoices 💸": "impira/layoutlm-invoices",
    "Donut 🍩": "naver-clova-ix/donut-base-finetuned-docvqa",
    "ggml-Vicuna": "eachadea/ggml-vicuna-13b-1.1",
}

# Cache of constructed pipelines, filled lazily by construct_pipeline().
PIPELINES = {}
31
+
32
+
33
def construct_pipeline(task, model):
    """Build a docquery pipeline for (task, model), caching the result.

    Args:
        task: pipeline task name, e.g. "document-question-answering".
        model: display-name key into CHECKPOINTS.

    Returns:
        The (possibly cached) pipeline object.
    """
    # BUG FIX: cache on (task, model) rather than model alone — keying on
    # model only would hand back a pipeline built for a *different* task if
    # the same model were ever requested with two tasks.  (No `global`
    # needed: we mutate PIPELINES, never rebind it.)
    key = (task, model)
    if key in PIPELINES:
        return PIPELINES[key]

    device = "cuda" if torch.cuda.is_available() else "cpu"
    ret = pipeline(task=task, model=CHECKPOINTS[model], device=device)
    PIPELINES[key] = ret
    return ret
42
+
43
+
44
def run_pipeline(model, question, document, top_k):
    """Ask *question* of *document* with the given model; return up to
    *top_k* predictions from the document-question-answering pipeline."""
    # Renamed the local from `pipeline` to `pipe`: the original shadowed the
    # `docquery.pipeline` factory imported at module level.
    pipe = construct_pipeline("document-question-answering", model)
    return pipe(question=question, **document.context, top_k=top_k)
47
+
48
+
49
+ # TODO: Move into docquery
50
+ # TODO: Support words past the first page (or window?)
51
def lift_word_boxes(document, page):
    """Return the OCR word boxes for one page of *document*.

    Each entry of document.context["image"] is indexed (image, word_boxes);
    this pulls out the word_boxes half for the requested page.
    """
    page_entry = document.context["image"][page]
    return page_entry[1]
53
+
54
+
55
def expand_bbox(word_boxes):
    """Union bounding box over word boxes.

    Each entry's index 1 is a [min_x, min_y, max_x, max_y] box; returns the
    box enclosing them all, or None for an empty input.
    """
    if not word_boxes:
        return None

    boxes = [entry[1] for entry in word_boxes]
    xs1, ys1, xs2, ys2 = zip(*boxes)
    return [min(xs1), min(ys1), max(xs2), max(ys2)]
62
+
63
+
64
+ # LayoutLM boxes are normalized to 0, 1000
65
# LayoutLM boxes are normalized to 0, 1000
def normalize_bbox(box, width, height, padding=0.005):
    """Scale a 0-1000-normalized LayoutLM box to pixel coordinates.

    An optional symmetric *padding* (in normalized units) grows the box,
    clamped to the [0, 1] normalized range, before scaling.
    """
    left, top, right, bottom = (coord / 1000 for coord in box)
    if padding != 0:
        left = max(0, left - padding)
        top = max(0, top - padding)
        right = min(right + padding, 1)
        bottom = min(bottom + padding, 1)
    return [left * width, top * height, right * width, bottom * height]
73
+
74
+
75
# Example rows for the gr.Examples widget: [image filename, question].
# NOTE(review): the "SaleData.xlsx" row has only one element while the
# widget is wired with two inputs (image, question) — confirm this row
# renders/loads as intended.
examples = [
    [
        "invoice.png",
        "What is the invoice number?",
    ],
    [
        "contract.jpeg",
        "What is the purchase amount?",
    ],
    [
        "statement.png",
        "What are net sales for 2020?",
    ],
    [
        "SaleData.xlsx",
    ],
    # [
    #     "docquery.png",
    #     "How many likes does the space have?",
    # ],
    # [
    #     "hacker_news.png",
    #     "What is the title of post number 5?",
    # ],
]
101
+
102
# Questions whose example should be answered against an alternate source
# (a PDF or a live URL) rather than the preview image shown in the example
# row; consulted by load_example_document().
question_files = {
    "What are net sales for 2020?": "statement.pdf",
    "How many likes does the space have?": "https://huggingface.co/spaces/impira/docquery",
    "What is the title of post number 5?": "https://news.ycombinator.com",
}
107
+
108
+
109
def process_path(path):
    """Load a document from a local path or URL and build the UI updates.

    Returns a 6-tuple matching the outputs wired to this callback:
        (document, gallery update, clear-button update, JSON-output update,
         answer-text update, url-error update).
    On failure the document slot is None and the url-error textbox is shown
    with the exception message.
    """
    error = None
    if path:
        try:
            document = load_document(path)
            return (
                document,
                gr.update(visible=True, value=document.preview),
                gr.update(visible=True),
                gr.update(visible=False, value=None),
                gr.update(visible=False, value=None),
                None,
            )
        except Exception as e:
            traceback.print_exc()
            error = str(e)
    # BUG FIX: the failure branch previously returned a 7-tuple (a stray
    # trailing None) while the success branch — and the 6-element outputs
    # lists wired in upload.change/submit.click — expect exactly 6 values.
    return (
        None,
        gr.update(visible=False, value=None),
        gr.update(visible=False),
        gr.update(visible=False, value=None),
        gr.update(visible=False, value=None),
        gr.update(visible=True, value=error) if error is not None else None,
    )
134
+
135
+
136
def process_upload(file):
    """Handle the upload widget: delegate to process_path() for a real file,
    otherwise reset all six wired outputs to their hidden/empty state."""
    if not file:
        return (
            None,
            gr.update(visible=False, value=None),
            gr.update(visible=False),
            gr.update(visible=False, value=None),
            gr.update(visible=False, value=None),
            None,
        )
    return process_path(file.name)
148
+
149
+
150
# Palette for up to three highlight boxes; not referenced by the visible
# code — presumably kept for the multi-prediction path, TODO confirm.
colors = ["#64A087", "green", "black"]
151
+
152
+
153
def process_question(question, document, model=list(CHECKPOINTS.keys())[0]):
    """Answer *question* against *document* and build the UI updates.

    Returns a 3-tuple of updates for (gallery, JSON output, answer textbox),
    or (None, None, None) when there is no question or no document.
    """
    if not question or document is None:
        return None, None, None

    text_value = None
    predictions = run_pipeline(model, question, document, 3)
    # Draw on copies so highlights never touch the document's cached preview.
    pages = [x.copy().convert("RGB") for x in document.preview]
    for i, p in enumerate(ensure_list(predictions)):
        if i == 0:
            text_value = p["answer"]
        else:
            # Keep the code around to produce multiple boxes, but only show the top
            # prediction for now
            break

        # Highlight the answer's words when the model reports their ids.
        # NOTE: the comprehension variable `i` below shadows the loop index.
        if "word_ids" in p:
            image = pages[p["page"]]
            draw = ImageDraw.Draw(image, "RGBA")
            word_boxes = lift_word_boxes(document, p["page"])
            x1, y1, x2, y2 = normalize_bbox(
                expand_bbox([word_boxes[i] for i in p["word_ids"]]),
                image.width,
                image.height,
            )
            # Semi-transparent green overlay over the answer span.
            draw.rectangle(((x1, y1), (x2, y2)), fill=(0, 255, 0, int(0.4 * 255)))

    return (
        gr.update(visible=True, value=pages),
        gr.update(visible=True, value=predictions),
        gr.update(
            visible=True,
            value=text_value,
        ),
    )
187
+
188
+
189
def load_example_document(img, question, model):
    """Load the clicked example and immediately answer its question.

    Questions listed in `question_files` resolve to an alternate PDF/URL;
    any other example is wrapped as an ImageDocument from the pixel array.
    Returns (document, question, gallery, clear-button update, JSON output,
    answer text) — all None/hidden when no image was provided.
    """
    if img is None:
        return None, None, None, gr.update(visible=False), None, None

    if question in question_files:
        document = load_document(question_files[question])
    else:
        document = ImageDocument(Image.fromarray(img), get_ocr_reader())
    preview, answer, answer_text = process_question(question, document, model)
    return document, question, preview, gr.update(visible=True), answer, answer_text
199
+
200
+
201
+ CSS = """
202
+ #question input {
203
+ font-size: 16px;
204
+ }
205
+ #url-textbox {
206
+ padding: 0 !important;
207
+ }
208
+ #short-upload-box .w-full {
209
+ min-height: 10rem !important;
210
+ }
211
+ /* I think something like this can be used to re-shape
212
+ * the table
213
+ */
214
+ /*
215
+ .gr-samples-table tr {
216
+ display: inline;
217
+ }
218
+ .gr-samples-table .p-2 {
219
+ width: 100px;
220
+ }
221
+ */
222
+ #select-a-file {
223
+ width: 100%;
224
+ }
225
+ #file-clear {
226
+ padding-top: 2px !important;
227
+ padding-bottom: 2px !important;
228
+ padding-left: 8px !important;
229
+ padding-right: 8px !important;
230
+ margin-top: 10px;
231
+ }
232
+ .gradio-container .gr-button-primary {
233
+ background: linear-gradient(180deg, #CDF9BE 0%, #AFF497 100%);
234
+ border: 1px solid #B0DCCC;
235
+ border-radius: 8px;
236
+ color: #1B8700;
237
+ }
238
+ .gradio-container.dark button#submit-button {
239
+ background: linear-gradient(180deg, #CDF9BE 0%, #AFF497 100%);
240
+ border: 1px solid #B0DCCC;
241
+ border-radius: 8px;
242
+ color: #1B8700
243
+ }
244
+
245
+ table.gr-samples-table tr td {
246
+ border: none;
247
+ outline: none;
248
+ }
249
+
250
+ table.gr-samples-table tr td:first-of-type {
251
+ width: 0%;
252
+ }
253
+
254
+ div#short-upload-box div.absolute {
255
+ display: none !important;
256
+ }
257
+
258
+ gradio-app > div > div > div > div.w-full > div, .gradio-app > div > div > div > div.w-full > div {
259
+ gap: 0px 2%;
260
+ }
261
+
262
+ gradio-app div div div div.w-full, .gradio-app div div div div.w-full {
263
+ gap: 0px;
264
+ }
265
+
266
+ gradio-app h2, .gradio-app h2 {
267
+ padding-top: 10px;
268
+ }
269
+
270
+ #answer {
271
+ overflow-y: scroll;
272
+ color: white;
273
+ background: #666;
274
+ border-color: #666;
275
+ font-size: 20px;
276
+ font-weight: bold;
277
+ }
278
+
279
+ #answer span {
280
+ color: white;
281
+ }
282
+
283
+ #answer textarea {
284
+ color:white;
285
+ background: #777;
286
+ border-color: #777;
287
+ font-size: 18px;
288
+ }
289
+
290
+ #url-error input {
291
+ color: red;
292
+ }
293
+ """
294
+
295
# ---------------------------------------------------------------------------
# UI definition and event wiring.  (Scrape stripped indentation; nesting
# below reconstructed from the `with` structure — confirm against upstream.)
# ---------------------------------------------------------------------------
with gr.Blocks(css=CSS) as demo:
    gr.Markdown("# DocQuery: Document Query Engine")
    gr.Markdown(
        "DocQuery (created by [Impira](https://impira.com?utm_source=huggingface&utm_medium=referral&utm_campaign=docquery_space))"
        " uses LayoutLMv1 fine-tuned on DocVQA, a document visual question"
        " answering dataset, as well as SQuAD, which boosts its English-language comprehension."
        " To use it, simply upload an image or PDF, type a question, and click 'submit', or "
        " click one of the examples to load them."
        " DocQuery is MIT-licensed and available on [Github](https://github.com/impira/docquery)."
    )

    # Hidden state: the loaded document, plus the invisible components that
    # gr.Examples writes the clicked example into.
    document = gr.Variable()
    example_question = gr.Textbox(visible=False)
    example_image = gr.Image(visible=False)

    with gr.Row(equal_height=True):
        # Left column: file selection (URL fetch or upload) and preview.
        with gr.Column():
            with gr.Row():
                gr.Markdown("## 1. Select a file", elem_id="select-a-file")
                img_clear_button = gr.Button(
                    "Clear", variant="secondary", elem_id="file-clear", visible=False
                )
            image = gr.Gallery(visible=False)
            with gr.Row(equal_height=True):
                with gr.Column():
                    with gr.Row():
                        url = gr.Textbox(
                            show_label=False,
                            placeholder="URL",
                            lines=1,
                            max_lines=1,
                            elem_id="url-textbox",
                        )
                        submit = gr.Button("Get")
                    url_error = gr.Textbox(
                        visible=False,
                        elem_id="url-error",
                        max_lines=1,
                        interactive=False,
                        label="Error",
                    )
                    gr.Markdown("— or —")
                    upload = gr.File(label=None, interactive=True, elem_id="short-upload-box")
                    gr.Examples(
                        examples=examples,
                        inputs=[example_image, example_question],
                    )

        # Right column: question entry, model choice, and outputs.
        with gr.Column() as col:
            gr.Markdown("## 2. Ask a question")
            question = gr.Textbox(
                label="Question",
                placeholder="e.g. What is the invoice number?",
                lines=1,
                max_lines=1,
            )
            model = gr.Radio(
                choices=list(CHECKPOINTS.keys()),
                value=list(CHECKPOINTS.keys())[0],
                label="Model",
            )

            with gr.Row():
                clear_button = gr.Button("Clear", variant="secondary")
                submit_button = gr.Button(
                    "Submit", variant="primary", elem_id="submit-button"
                )
            with gr.Column():
                output_text = gr.Textbox(
                    label="Top Answer", visible=False, elem_id="answer"
                )
                output = gr.JSON(label="Output", visible=False)

    # Both clear buttons reset every interactive component to its initial
    # hidden/empty state (10 return values for the 10 outputs below).
    for cb in [img_clear_button, clear_button]:
        cb.click(
            lambda _: (
                gr.update(visible=False, value=None),
                None,
                gr.update(visible=False, value=None),
                gr.update(visible=False, value=None),
                gr.update(visible=False),
                None,
                None,
                None,
                gr.update(visible=False, value=None),
                None,
            ),
            inputs=clear_button,
            outputs=[
                image,
                document,
                output,
                output_text,
                img_clear_button,
                example_image,
                upload,
                url,
                url_error,
                question,
            ],
        )

    # Document loading: from the upload widget or from a fetched URL.
    upload.change(
        fn=process_upload,
        inputs=[upload],
        outputs=[document, image, img_clear_button, output, output_text, url_error],
    )
    submit.click(
        fn=process_path,
        inputs=[url],
        outputs=[document, image, img_clear_button, output, output_text, url_error],
    )

    # Question answering: triggered by Enter, the submit button, or a
    # model switch (re-answers the current question with the new model).
    question.submit(
        fn=process_question,
        inputs=[question, document, model],
        outputs=[image, output, output_text],
    )

    submit_button.click(
        process_question,
        inputs=[question, document, model],
        outputs=[image, output, output_text],
    )

    model.change(
        process_question,
        inputs=[question, document, model],
        outputs=[image, output, output_text],
    )

    # Clicking an example loads its document and immediately answers.
    example_image.change(
        fn=load_example_document,
        inputs=[example_image, example_question, model],
        outputs=[document, question, image, img_clear_button, output, output_text],
    )
431
+
432
if __name__ == "__main__":
    # Launch without request queueing.
    # NOTE(review): `enable_queue` was removed in Gradio 4.x — confirm the
    # Space's pinned gradio version still accepts this keyword.
    demo.launch(enable_queue=False)