|
|
|
|
|
import gradio as gr
from transformers import DonutProcessor, VisionEncoderDecoderModel
from PIL import Image
import torch
import re
import json
import os
import warnings

# Silence noisy deprecation warnings (e.g., torch's TypedStorage notice).
warnings.filterwarnings("ignore", category=UserWarning, message="TypedStorage is deprecated")
warnings.filterwarnings("ignore", category=FutureWarning)

# Checkpoints: a Donut model fine-tuned on SROIE receipts, and the base Donut model for comparison.
model_path_finetuned = "greene6517/finetuned_donut_sroie"
model_name_base = "naver-clova-ix/donut-base"

print(f"Loading Fine-tuned processor from Hub: {model_path_finetuned}") |
|
try: |
|
|
|
processor = DonutProcessor.from_pretrained(model_path_finetuned) |
|
print("Successfully loaded fine-tuned processor from Hub.") |
|
except Exception as e: |
|
print(f"FATAL: Could not load fine-tuned processor from Hub: {e}") |
|
exit() |
|
|
|
print(f"Loading Fine-tuned model from Hub: {model_path_finetuned}") |
|
try: |
|
|
|
model_finetuned = VisionEncoderDecoderModel.from_pretrained(model_path_finetuned) |
|
print("Successfully loaded fine-tuned model from Hub.") |
|
except Exception as e: |
|
print(f"FATAL: Could not load fine-tuned model from Hub: {e}") |
|
exit() |
|
|
|
|
|
print(f"Loading Fine-tuned model from: {model_path_finetuned}") |
|
try: |
|
|
|
model_finetuned = VisionEncoderDecoderModel.from_pretrained(model_path_finetuned, local_files_only=True) |
|
print("Successfully loaded fine-tuned model locally from Space repo.") |
|
except Exception as e: |
|
print(f"Error loading fine-tuned model locally: {e}. Check if model files exist at the path.") |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
print(f"FATAL: Could not load fine-tuned model locally: {e}") |
|
exit() |
|
|
|
|
|
|
|
print(f"Loading Base processor from: {model_name_base}") |
|
try: |
|
processor_base = DonutProcessor.from_pretrained(model_name_base) |
|
print("Successfully loaded base processor.") |
|
except Exception as e: |
|
print(f"FATAL: Could not load base processor: {e}") |
|
exit() |
|
|
|
print(f"Loading Base model from: {model_name_base}") |
|
try: |
|
model_base = VisionEncoderDecoderModel.from_pretrained(model_name_base) |
|
print("Successfully loaded base model.") |
|
except Exception as e: |
|
print(f"FATAL: Could not load base model: {e}") |
|
exit() |
|
|
|
|
|
|
|
|
|
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"\nUsing device: {device}")

try:
    model_finetuned.to(device)
    model_base.to(device)
    print("Models moved to device.")

    model_finetuned.eval()
    model_base.eval()
    print("Models set to evaluation mode.")
except Exception as e:
    print(f"Error moving models to device or setting eval mode: {e}")
    exit()

def clean_sequence(sequence, processor_to_use, prompt_token_str=None):
    """Removes the task prompt, EOS, PAD, and BOS tokens from a generated sequence."""
    cleaned = sequence
    try:
        eos_token = processor_to_use.tokenizer.eos_token if processor_to_use.tokenizer.eos_token else "</s>"
        pad_token = processor_to_use.tokenizer.pad_token if processor_to_use.tokenizer.pad_token else "<pad>"
        cleaned = cleaned.replace(eos_token, "").replace(pad_token, "").strip()

        if hasattr(processor_to_use.tokenizer, 'bos_token') and processor_to_use.tokenizer.bos_token:
            cleaned = cleaned.replace(processor_to_use.tokenizer.bos_token, "").strip()

        # Strip the leading task prompt (e.g., "<s_sroie>") if present.
        if prompt_token_str and cleaned.startswith(prompt_token_str):
            cleaned = cleaned[len(prompt_token_str):].strip()
    except Exception as e:
        print(f"Warning: Error during sequence cleaning: {e}")
        return sequence
    return cleaned
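
# Illustrative usage (hypothetical token strings, assuming "</s>" is the EOS token):
#   clean_sequence("<s_sroie><s_total>10.00</s_total></s>", processor, "<s_sroie>")
#   -> "<s_total>10.00</s_total>"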
|
|
|
|
|
def token2json_simple(text):
    """Parses the flat <s_key>value</s_key> format into a dictionary."""
    output = {}
    # Non-greedy match of <s_key>...</s_key> pairs; the backreference \1 requires
    # the closing tag to carry the same key as the opening tag.
    parts = re.findall(r"<s_(.*?)>([\s\S]*?)</s_\1>", text)
    for key, value in parts:
        output[key.strip()] = value.strip()

    if not output and text and not text.isspace():
        output["parsing_info"] = "Could not parse SROIE key-value pairs from the cleaned sequence."
        output["cleaned_sequence_preview"] = text[:200] + "..."
    elif not text or text.isspace():
        output["parsing_info"] = "Empty sequence after cleaning, nothing to parse."

    return output
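
# Illustrative usage (hypothetical SROIE-style fields):
#   token2json_simple("<s_company>ACME MART</s_company><s_total>12.50</s_total>")
#   -> {"company": "ACME MART", "total": "12.50"}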
|
|
|
|
|
|
|
@torch.no_grad()
def process_image_comparison(image_input):
    if image_input is None:
        no_image_msg = {"error": "Please upload an image."}
        return json.dumps(no_image_msg, indent=2, ensure_ascii=False), json.dumps(no_image_msg, indent=2, ensure_ascii=False)

    try:
        # Gradio provides the image as a numpy array; convert it to an RGB PIL image.
        image = Image.fromarray(image_input).convert("RGB")
    except Exception as e:
        error_msg = {"error": f"Image conversion error: {e}"}
        error_json_str = json.dumps(error_msg, indent=2, ensure_ascii=False)
        return error_json_str, error_json_str

    results_ft_json_str = "{}"
    results_base_json_str = "{}"
    sequence_ft_raw = "N/A"
    sequence_base_raw = "N/A"
|
|
|
|
|
    # Inference with the fine-tuned model, prompted with its SROIE task token.
    try:
        pixel_values_ft = processor(image, return_tensors="pt").pixel_values.to(device)
        task_prompt_ft = "<s_sroie>"
        decoder_input_ids_ft = processor.tokenizer(
            task_prompt_ft, add_special_tokens=False, return_tensors="pt"
        ).input_ids.to(device)

        generation_config_ft = {
            "max_length": model_finetuned.config.decoder.max_position_embeddings,
            "pad_token_id": processor.tokenizer.pad_token_id,
            "eos_token_id": processor.tokenizer.eos_token_id,
            "use_cache": True,
            "bad_words_ids": [[processor.tokenizer.unk_token_id]] if processor.tokenizer.unk_token_id else None,
            "return_dict_in_generate": True,
            "decoder_input_ids": decoder_input_ids_ft,
        }
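        # Passing decoder_input_ids seeds the decoder with the task prompt, so
        # generation continues after "<s_sroie>" instead of the default start token.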
|
|
|
        outputs_ft = model_finetuned.generate(pixel_values_ft, **generation_config_ft)
        sequence_ft_raw = processor.batch_decode(outputs_ft.sequences)[0]

        # Strip prompt/special tokens, then parse the key-value tags into JSON.
        sequence_ft_cleaned = clean_sequence(sequence_ft_raw, processor, prompt_token_str=task_prompt_ft)
        result_json_ft = token2json_simple(sequence_ft_cleaned)
        result_json_ft["raw_decoded_sequence_preview"] = sequence_ft_raw[:200] + "..."

        results_ft_json_str = json.dumps(result_json_ft, indent=2, ensure_ascii=False, sort_keys=False)
    except Exception as e:
        print(f"Error during fine-tuned model inference: {e}")
        import traceback
        traceback.print_exc()
        results_ft_json_str = json.dumps({
            "error": f"Fine-tuned model inference error: {e}",
            "raw_decoded_sequence_before_error": sequence_ft_raw
        }, indent=2, ensure_ascii=False)
|
|
|
|
|
    # Inference with the base model, prompted with <s_iitcdip>, the task token used
    # during Donut's pretraining.
    try:
        pixel_values_base = processor_base(image, return_tensors="pt").pixel_values.to(device)
        task_prompt_base = "<s_iitcdip>"

        try:
            decoder_input_ids_base = processor_base.tokenizer(
                task_prompt_base,
                add_special_tokens=False,
                return_tensors="pt",
            ).input_ids.to(device)
        except Exception as tokenizer_e:
            print(f"Warning: Base processor cannot tokenize prompt '{task_prompt_base}'. Using default generation. Error: {tokenizer_e}")
            decoder_input_ids_base = None

        generation_config_base = {
            "max_length": model_base.config.decoder.max_position_embeddings,
            "early_stopping": True,
            "pad_token_id": processor_base.tokenizer.pad_token_id,
            "eos_token_id": processor_base.tokenizer.eos_token_id,
            "use_cache": True,
            "num_beams": 1,
            "bad_words_ids": [[processor_base.tokenizer.unk_token_id]] if processor_base.tokenizer.unk_token_id else None,
            "return_dict_in_generate": True,
        }
        if decoder_input_ids_base is not None:
            generation_config_base["decoder_input_ids"] = decoder_input_ids_base

        outputs_base = model_base.generate(pixel_values_base, **generation_config_base)
        sequence_base_raw = processor_base.batch_decode(outputs_base.sequences)[0]
        sequence_base_cleaned = processor_base.batch_decode(outputs_base.sequences, skip_special_tokens=True)[0]
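
        # The base model has no SROIE key tokens, so instead of key-value parsing we
        # report the raw decoded sequence and a version with special tokens stripped.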
|
|
|
|
|
|
|
        result_json_base = {
            "raw_decoded_sequence_preview": sequence_base_raw[:200] + "...",
            "output_skip_special_tokens": sequence_base_cleaned
        }
        results_base_json_str = json.dumps(result_json_base, indent=2, ensure_ascii=False, sort_keys=False)
    except Exception as e:
        print(f"Error during base model inference: {e}")
        import traceback
        traceback.print_exc()
        results_base_json_str = json.dumps({
            "error": f"Base model inference error: {e}",
            "raw_decoded_sequence_before_error": sequence_base_raw
        }, indent=2, ensure_ascii=False)
|
|
|
|
|
    return results_ft_json_str, results_base_json_str
|
|
|
|
|
|
|
|
|
custom_css = """ |
|
body { background-color: #f0f4f8; font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif; } |
|
#main_title { text-align: center; color: #1a5276; font-size: 2.3em; font-weight: 600; margin-top: 20px; margin-bottom: 5px; } |
|
#sub_description { text-align: center; color: #566573; font-size: 1.0em; margin-bottom: 25px; } |
|
.gradio-container { border-radius: 10px !important; box-shadow: 0 3px 10px rgba(0,0,0,0.08); padding: 25px !important; } |
|
footer { display: none !important; } /* Hide Gradio footer */ |
|
#output-title-ft, #output-title-base { color: #1a5276; font-weight: 600; margin-bottom: 8px; font-size: 1.2em; border-bottom: 2px solid #aed6f1; padding-bottom: 4px; } |
|
#output_row > div.gradio-column { border: 1px solid #d5dbdb; padding: 15px !important; border-radius: 8px; background-color: #ffffff; margin: 0 8px !important; box-shadow: 0 1px 3px rgba(0,0,0,0.04); } |
|
#json_output_ft > div:nth-child(2), #json_output_base > div:nth-child(2) { max-height: 600px; overflow-y: auto !important; } /* JSON output scroll */ |
|
""" |
|
|
|
|
|
with gr.Blocks(css=custom_css, theme=gr.themes.Soft(primary_hue="blue", secondary_hue="sky")) as demo:
    gr.Markdown("# Donut Model Comparison: Fine-tuned vs Base", elem_id="main_title")
    gr.Markdown("Upload a receipt image to compare extraction results from the fine-tuned model (SROIE parsing) and the base model.", elem_id="sub_description")

    with gr.Row():
        with gr.Column(scale=1):
            image_input = gr.Image(type="numpy", label="🧾 Upload a receipt image")
            submit_btn = gr.Button("🚀 Compare Results", variant="primary", scale=0)

            # Offer bundled example images if they exist in the Space repo.
            example_img_dir = "example"
            example_paths = [os.path.join(example_img_dir, f) for f in ["1.jpg", "2.jpg"] if os.path.exists(os.path.join(example_img_dir, f))]
            if example_paths:
                gr.Examples(examples=example_paths, inputs=image_input, label="Click an example image (then press the 'Compare Results' button)")
            else:
                gr.Markdown("_(Example images not found. Check the 'example' folder.)_")

        with gr.Column(scale=2):
            with gr.Row(elem_id="output_row"):
                with gr.Column(scale=1):
                    gr.Markdown("### ✨ Fine-tuned Model (SROIE Parsing)", elem_id="output-title-ft")
                    json_output_ft = gr.JSON(label="Fine-tuned Result (JSON)", elem_id="json_output_ft")
                with gr.Column(scale=1):
                    gr.Markdown("### 💡 Base Model (Raw + Cleaned)", elem_id="output-title-base")
                    json_output_base = gr.JSON(label="Base Model Result (JSON)", elem_id="json_output_base")

    submit_btn.click(
        fn=process_image_comparison,
        inputs=image_input,
        outputs=[json_output_ft, json_output_base]
    )
|
|
|
|
|
|
|
if __name__ == "__main__":
    demo.launch()
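
# The bare launch() above is the usual setup for Hugging Face Spaces; for local
# testing, demo.launch(share=True) can additionally expose a temporary public link.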