VietOCR / app.py
hantech's picture
Update app.py
79202d5 verified
raw
history blame contribute delete
2.45 kB
import os
import gradio as gr
import omegaconf
import torch
import numpy
import easyocr
from PIL import Image
from vietocr.model.transformerocr import VietOCR
from vietocr.model.vocab import Vocab
from vietocr.translate import translate, process_input
# EasyOCR reader built with the 'vi' language list; predict() calls its
# readtext() to obtain text regions and a first-pass transcription.
reader = easyocr.Reader(['vi'])

# Paths of the entries found in the local 'examples' directory. Each entry
# name is cut at its first tab character before being joined to the folder.
examples_data = [
    os.path.join('examples', entry.split('\t')[0])
    for entry in os.listdir('examples')
]

# Load the YAML configuration and convert it to plain Python containers,
# resolving any interpolations.
config = omegaconf.OmegaConf.to_container(
    omegaconf.OmegaConf.load("vgg-seq2seq.yaml"),
    resolve=True,
)

# Vocabulary and recognition model are both built from the configuration.
vocab = Vocab(config['vocab'])
model = VietOCR(
    len(vocab),
    config['backbone'],
    config['cnn'],
    config['transformer'],
    config['seq_modeling'],
)

# Restore the trained weights, mapping every tensor onto the CPU.
model.load_state_dict(
    torch.load('train_old.pth', map_location=torch.device('cpu'))
)
def viet_ocr_predict(inp):
    """Recognise the text in a single image with the VietOCR model.

    ``inp`` is handed to ``process_input`` together with the image height
    and the minimum/maximum image width read from ``config['dataset']``.
    The first sequence returned by ``translate`` is converted to a list and
    decoded with the module-level ``vocab``.

    Returns whatever ``vocab.decode`` produces for that sequence.
    """
    dataset_cfg = config['dataset']
    model_input = process_input(
        inp,
        dataset_cfg['image_height'],
        dataset_cfg['image_min_width'],
        dataset_cfg['image_max_width'],
    )
    token_ids = translate(model_input, model)[0].tolist()
    return vocab.decode(token_ids)
def _clamped_bbox(bbox, width, height):
    """Return ``(min_x, min_y, max_x, max_y)`` enclosing *bbox*, clipped to the image.

    *bbox* is a sequence of ``(x, y)`` corner points; coordinates are
    truncated to ``int``. The maxima are clipped to ``width`` / ``height``
    (not ``width - 1`` / ``height - 1``) because they are used as exclusive
    slice ends by the caller.
    """
    xs = [int(point[0]) for point in bbox]
    ys = [int(point[1]) for point in bbox]
    min_x = max(0, min(xs))
    max_x = min(width, max(xs))
    min_y = max(0, min(ys))
    max_y = min(height, max(ys))
    return min_x, min_y, max_x, max_y


def predict(filepath):
    """Recognise all text regions in the image stored at *filepath*.

    ``reader.readtext`` supplies ``(bbox, text, prob)`` triples. Each
    bounding box is cropped out of the image and passed to
    ``viet_ocr_predict``; if that fails for a region, the text reported by
    ``reader.readtext`` for the same region is used instead.

    Every recognised piece is printed as it is produced.

    Returns:
        A single string in which each region's text is preceded by a tab
        character (so a non-empty result starts with ``'\\t'``); an empty
        string when no region was detected.
    """
    bounds = reader.readtext(filepath)

    # Read the pixels inside a context manager so the file is closed promptly.
    with Image.open(filepath) as im:
        inp = numpy.asarray(im)

    # Image arrays are indexed (row, column[, channel]): height comes first.
    # Taking only the first two axes also supports 2-D (grayscale) arrays.
    height, width = inp.shape[:2]

    pieces = []
    for bbox, text, _prob in bounds:
        min_x, min_y, max_x, max_y = _clamped_bbox(bbox, width, height)
        try:
            # Crop the region of interest; trailing axes are kept implicitly.
            cropped_image = Image.fromarray(inp[min_y:max_y, min_x:max_x])
            out = viet_ocr_predict(cropped_image)
        except Exception:
            # Best effort: fall back to the detector's own transcription
            # without trapping KeyboardInterrupt / SystemExit.
            out = text
        print(out)
        pieces.append(out)

    return ''.join('\t' + piece for piece in pieces)
# Build the web UI: one image input (delivered to predict() as a file path)
# and one text output, then start serving it.
demo = gr.Interface(
    fn=predict,
    title='Vietnamese Handwriting Recognition',
    inputs=gr.Image(type='filepath'),
    outputs=gr.Text(),
    # examples=examples_data,
)
demo.launch()