|
from vietocr.model.backbone.cnn import CNN |
|
from vietocr.model.seqmodel.transformer import LanguageTransformer |
|
from vietocr.model.seqmodel.seq2seq import Seq2Seq |
|
from vietocr.model.seqmodel.convseq2seq import ConvSeq2Seq |
|
from torch import nn |
|
|
|
class VietOCR(nn.Module):
    """OCR model: a CNN feature extractor followed by a sequence model.

    The image is passed through ``self.cnn`` and the result is handed to
    ``self.transformer`` together with the target tokens. Despite its name,
    ``self.transformer`` holds whichever sequence model ``seq_modeling``
    selected; the attribute name is kept so existing code and checkpoints
    that refer to it continue to work.
    """

    # Values accepted for ``seq_modeling``.
    SUPPORTED_SEQ_MODELS = ('transformer', 'seq2seq', 'convseq2seq')

    def __init__(self, vocab_size,
                 backbone,
                 cnn_args,
                 transformer_args, seq_modeling='transformer'):
        """Build the backbone and the selected sequence model.

        Args:
            vocab_size: forwarded as the first argument of the sequence model.
            backbone: forwarded as the first argument of ``CNN``.
            cnn_args: dict of keyword arguments expanded into ``CNN``.
            transformer_args: dict of keyword arguments expanded into the
                selected sequence model.
            seq_modeling: one of ``'transformer'``, ``'seq2seq'`` or
                ``'convseq2seq'``.

        Raises:
            ValueError: if ``seq_modeling`` is not a supported value.
        """
        # Validate first so an unsupported option fails before the backbone
        # is constructed. The original `raise('...')` raised a plain string,
        # which Python turns into an unrelated TypeError.
        if seq_modeling not in self.SUPPORTED_SEQ_MODELS:
            raise ValueError(
                'Not Support Seq Model: {!r} (expected one of {})'.format(
                    seq_modeling, ', '.join(self.SUPPORTED_SEQ_MODELS)))

        super(VietOCR, self).__init__()

        self.cnn = CNN(backbone, **cnn_args)
        self.seq_modeling = seq_modeling

        if seq_modeling == 'transformer':
            self.transformer = LanguageTransformer(vocab_size, **transformer_args)
        elif seq_modeling == 'seq2seq':
            self.transformer = Seq2Seq(vocab_size, **transformer_args)
        else:  # 'convseq2seq' -- the only remaining supported value
            self.transformer = ConvSeq2Seq(vocab_size, **transformer_args)

    def forward(self, img, tgt_input, tgt_key_padding_mask):
        """Run the backbone and the sequence model.

        Shape:
            - img: (N, C, H, W)
            - tgt_input: (T, N)
            - tgt_key_padding_mask: (N, T)
            - output: b t v

        ``tgt_key_padding_mask`` is only passed on when ``seq_modeling`` is
        ``'transformer'``; the other sequence models are called with the
        features and ``tgt_input`` alone.
        """
        src = self.cnn(img)

        if self.seq_modeling == 'transformer':
            return self.transformer(
                src, tgt_input, tgt_key_padding_mask=tgt_key_padding_mask)

        # 'seq2seq' and 'convseq2seq' share the same call signature here.
        return self.transformer(src, tgt_input)
|
|
|
|