# This module is from [WeNet](https://github.com/wenet-e2e/wenet).

# ## Citations

# ```bibtex
# @inproceedings{yao2021wenet,
#   title={WeNet: Production oriented Streaming and Non-streaming End-to-End Speech Recognition Toolkit},
#   author={Yao, Zhuoyuan and Wu, Di and Wang, Xiong and Zhang, Binbin and Yu, Fan and Yang, Chao and Peng, Zhendong and Chen, Xiaoyu and Xie, Lei and Lei, Xin},
#   booktitle={Proc. Interspeech},
#   year={2021},
#   address={Brno, Czech Republic},
#   organization={IEEE}
# }

# @article{zhang2022wenet,
#   title={WeNet 2.0: More Productive End-to-End Speech Recognition Toolkit},
#   author={Zhang, Binbin and Wu, Di and Peng, Zhendong and Song, Xingchen and Yao, Zhuoyuan and Lv, Hang and Xie, Lei and Yang, Chao and Pan, Fuping and Niu, Jianwei},
#   journal={arXiv preprint arXiv:2203.15455},
#   year={2022}
# }
# ```

import torch
import torch.nn.functional as F


class CTC(torch.nn.Module):
    """CTC module"""

    def __init__(

        self,

        odim: int,

        encoder_output_size: int,

        dropout_rate: float = 0.0,

        reduce: bool = True,

    ):
        """Construct CTC module

        Args:

            odim: dimension of outputs

            encoder_output_size: number of encoder projection units

            dropout_rate: dropout rate (0.0 ~ 1.0)

            reduce: reduce the CTC loss into a scalar

        """
        super().__init__()
        eprojs = encoder_output_size
        self.dropout_rate = dropout_rate
        self.ctc_lo = torch.nn.Linear(eprojs, odim)

        reduction_type = "sum" if reduce else "none"
        self.ctc_loss = torch.nn.CTCLoss(reduction=reduction_type)

    def forward(
        self,
        hs_pad: torch.Tensor,
        hlens: torch.Tensor,
        ys_pad: torch.Tensor,
        ys_lens: torch.Tensor,
    ) -> torch.Tensor:
        """Calculate CTC loss.



        Args:

            hs_pad: batch of padded hidden state sequences (B, Tmax, D)

            hlens: batch of lengths of hidden state sequences (B)

            ys_pad: batch of padded character id sequence tensor (B, Lmax)

            ys_lens: batch of lengths of character sequence (B)

        """
        # hs_pad: (B, L, NProj) -> ys_hat: (B, L, Nvocab)
        ys_hat = self.ctc_lo(F.dropout(hs_pad, p=self.dropout_rate))
        # ys_hat: (B, L, D) -> (L, B, D)
        ys_hat = ys_hat.transpose(0, 1)
        ys_hat = ys_hat.log_softmax(2)
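        # torch.nn.CTCLoss expects log-probabilities of shape (T, B, V);
        # the blank label index defaults to 0 in PyTorch.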
        loss = self.ctc_loss(ys_hat, ys_pad, hlens, ys_lens)
        # Batch-size average
        loss = loss / ys_hat.size(1)
        return loss

    def log_softmax(self, hs_pad: torch.Tensor) -> torch.Tensor:
        """Apply log_softmax to frame activations.

        Args:
            hs_pad (torch.Tensor): 3D tensor (B, Tmax, eprojs)

        Returns:
            torch.Tensor: log-softmax applied 3D tensor (B, Tmax, odim)
        """
        return F.log_softmax(self.ctc_lo(hs_pad), dim=2)

    def argmax(self, hs_pad: torch.Tensor) -> torch.Tensor:
        """Take the argmax of frame activations.

        Args:
            hs_pad (torch.Tensor): 3D tensor (B, Tmax, eprojs)

        Returns:
            torch.Tensor: argmax applied 2D tensor (B, Tmax)
        """
        return torch.argmax(self.ctc_lo(hs_pad), dim=2)
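

# --- Minimal usage sketch (not part of the original WeNet module) ---
# Illustrates the tensor shapes the CTC module expects. The batch size,
# sequence lengths, and vocabulary size below are illustrative
# assumptions, not values taken from the source.
if __name__ == "__main__":
    torch.manual_seed(0)
    batch, tmax, eprojs, vocab = 2, 50, 256, 100
    ctc = CTC(odim=vocab, encoder_output_size=eprojs, dropout_rate=0.1)

    hs_pad = torch.randn(batch, tmax, eprojs)      # encoder outputs (B, Tmax, D)
    hlens = torch.tensor([50, 42])                 # valid frames per utterance
    ys_pad = torch.randint(1, vocab, (batch, 10))  # targets; 0 is the CTC blank
    ys_lens = torch.tensor([10, 8])                # valid labels per utterance

    loss = ctc(hs_pad, hlens, ys_pad, ys_lens)
    print(loss.item())  # batch-averaged CTC loss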