|
import os |
|
import gradio as gr |
|
import omegaconf |
|
import torch |
|
import numpy |
|
|
|
import easyocr |
|
from PIL import Image |
|
|
|
from vietocr.model.transformerocr import VietOCR |
|
from vietocr.model.vocab import Vocab |
|
from vietocr.translate import translate, process_input |
|
|
|
# EasyOCR pipeline configured for Vietnamese; ``predict`` uses it to locate text regions.
reader = easyocr.Reader(['vi'])

# Paths of the entries in the local ``examples`` directory. Each entry name is cut at
# its first tab character before being joined onto the directory.
# NOTE(review): not referenced by the Gradio interface below — confirm whether it
# was meant to be passed as ``examples=``.
examples_data = [
    os.path.join('examples', entry.split('\t')[0])
    for entry in os.listdir('examples')
]
|
|
|
# Model hyper-parameters live in the YAML file. ``resolve=True`` expands any
# interpolations so the config can be indexed as a plain nested container.
config = omegaconf.OmegaConf.to_container(
    omegaconf.OmegaConf.load("vgg-seq2seq.yaml"), resolve=True
)

# Build the VietOCR network from the config sections and load the trained weights,
# mapping every tensor onto the CPU.
vocab = Vocab(config['vocab'])
model = VietOCR(
    len(vocab),
    config['backbone'],
    config['cnn'],
    config['transformer'],
    config['seq_modeling'],
)
# NOTE: torch.load unpickles the checkpoint; only ship/load trusted weight files.
model.load_state_dict(
    torch.load('train_old.pth', map_location=torch.device('cpu'))
)
|
def viet_ocr_predict(inp):
    """Transcribe one cropped text-region image with the VietOCR model.

    The image is resized and normalised by ``process_input`` using the dataset
    sizes from ``config``. It is decoded greedily by ``translate``, and the
    resulting token ids are turned back into text with ``vocab``.

    Args:
        inp: A PIL image of a single text region (``predict`` passes the output
            of ``Image.fromarray``).

    Returns:
        The decoded string.
    """
    img = process_input(
        inp,
        config['dataset']['image_height'],
        config['dataset']['image_min_width'],
        config['dataset']['image_max_width'],
    )
    result = translate(img, model)
    # Newer vietocr releases return ``(sentences, char_probs)``; older ones return
    # the sentence-id array alone. Without this unwrap, newer versions hand
    # ``vocab.decode`` a nested list and it raises.
    if isinstance(result, tuple):
        result = result[0]
    # Only one image was fed in, so row 0 holds its token ids.
    return vocab.decode(result[0].tolist())
|
def _crop_box(bbox, width, height):
    """Return the clamped slice bounds ``(min_x, min_y, max_x, max_y)`` of a quad.

    ``bbox`` is an EasyOCR polygon, i.e. a sequence of ``(x, y)`` corner points.
    Coordinates are truncated to ``int`` and clamped to the image. ``max_x`` and
    ``max_y`` are exclusive, so they are clamped to ``width``/``height`` rather
    than ``width - 1``/``height - 1``, allowing the last column/row to be
    included.
    """
    xs = [int(point[0]) for point in bbox]
    ys = [int(point[1]) for point in bbox]
    min_x = max(0, min(xs))
    max_x = min(width, max(xs))
    min_y = max(0, min(ys))
    max_y = min(height, max(ys))
    return min_x, min_y, max_x, max_y


def _recognize_line(image_array, box, fallback):
    """Re-read one region with VietOCR, returning *fallback* if that is not possible."""
    min_x, min_y, max_x, max_y = box
    if max_x <= min_x or max_y <= min_y:
        # Zero-area region after clamping: nothing for VietOCR to read.
        return fallback
    try:
        crop = Image.fromarray(image_array[min_y:max_y, min_x:max_x])
        return viet_ocr_predict(crop)
    except Exception:
        # Deliberate best-effort: keep EasyOCR's transcription for this region.
        return fallback


def predict(filepath):
    """Recognize Vietnamese text in the image stored at *filepath*.

    ``reader`` (EasyOCR) finds the text regions. Each region is cropped and
    re-read by VietOCR via ``viet_ocr_predict``. When that fails for a region,
    EasyOCR's own text for it is used instead.

    Args:
        filepath: Path to an image file, as supplied by
            ``gr.Image(type='filepath')``.

    Returns:
        The text of every region, in the order EasyOCR reported them, joined
        with tab characters.
    """
    bounds = reader.readtext(filepath)

    # Convert to RGB so the array is always (rows, cols, 3). Grayscale or palette
    # images would otherwise give a 2-D array. The context manager closes the file.
    with Image.open(filepath) as im:
        inp = numpy.asarray(im.convert('RGB'))
    # numpy image arrays are indexed (row, col): rows are y/height, cols are x/width.
    height, width = inp.shape[:2]

    texts = []
    for bbox, text, _prob in bounds:
        out = _recognize_line(inp, _crop_box(bbox, width, height), text)
        print(out)
        texts.append(out)
    return '\t'.join(texts)
|
|
|
# Gradio front end: the uploaded image reaches ``predict`` as a file path and the
# recognized text is shown in a text box.
demo = gr.Interface(
    fn=predict,
    title='Vietnamese Handwriting Recognition',
    inputs=gr.Image(type='filepath'),
    outputs=gr.Text(),
)
demo.launch()