
How to Use CTCLoss

Author: Internet


What is CTC

CTC stands for Connectionist Temporal Classification.

CTCLoss is a loss function that computes the loss between the model output \(y\) and the label \(label\).

\[loss=CTCLoss(y,label) \]

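Training a neural network is the process of driving \(loss\) down. CTCLoss is commonly used in image text recognition (OCR) and speech recognition, because computing it does not require \(y\) and \(label\) to be aligned. This greatly reduces the work of annotating alignments in the data.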

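Architecture overview

This article is mainly about CTCLoss; the model architecture is described here only to show where the CTCLoss function fits in the overall pipeline. Suppose we have a piece of raw data, which could be an image containing text or a recording of speech.

For audio, the raw sound is passed through a DFT (discrete Fourier transform) to obtain a feature map with time-frequency characteristics. The feature map is fed through the network \(\mathcal{N}_w\), which outputs \(y\) (\(y\in\mathbb{R}^{K \times T}\), where the \(K\) dimension holds the predicted probability of each symbol at a given time step and \(T\) is the time dimension).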

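A simple example

Suppose we have a recording of a person spelling the English word “CAT”, so the audio contains the three letters “C”, “A”, and “T”. The speaker takes 5 s to read them. We want to recognize these three letters from the audio.

First we need a vocabulary for the 26 letters: indices 1-26 represent the letters A-Z, and index 0 represents blank. Blank marks the parts of the input that do not belong to any of the 26 letters. Next, assume the model outputs one probability distribution over the vocabulary per second.
The audio lasts 5 s, so there are 5 columns of such distributions.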

\[y\in\mathbb{R}^{K \times T}\ \ \ \ \ (K=27,\ T=5) \]

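The table below gives the probability distribution \(y\); each column is the distribution for the input at that time step.

Table 1. Probability distribution over output symbols at each time step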
\(y_t^k\) t=1 t=2 t=3 t=4 t=5
k=0 (-) 0.031953 0.044296 0.038297 0.038320 0.027464
k=1 (A) 0.026221 0.030363 0.031878 0.027295 0.029824
k=2 (B) 0.040555 0.025838 0.023487 0.041529 0.028116
k=3 (C) 0.029333 0.045889 0.031872 0.023184 0.029338
k=4 (D) 0.023595 0.053792 0.022519 0.039882 0.025342
k=5 (E) 0.048014 0.028887 0.020526 0.041302 0.045833
k=6 (F) 0.028770 0.040735 0.045488 0.044244 0.032191
k=7 (G) 0.035127 0.032281 0.034032 0.051973 0.041613
k=8 (H) 0.044897 0.047910 0.049222 0.056956 0.048665
k=9 (I) 0.032323 0.044911 0.038994 0.046017 0.040002
k=10 (J) 0.047130 0.024608 0.034797 0.038146 0.041496
k=11 (K) 0.033491 0.049294 0.043909 0.053962 0.037901
k=12 (L) 0.044700 0.056019 0.046794 0.038094 0.027488
k=13 (M) 0.045632 0.034822 0.052229 0.021692 0.039653
k=14 (N) 0.035123 0.050406 0.019438 0.024067 0.056986
k=15 (O) 0.023015 0.037482 0.046163 0.050536 0.058191
k=16 (P) 0.031419 0.024302 0.035848 0.034614 0.031820
k=17 (Q) 0.034497 0.025424 0.052284 0.049642 0.029912
k=18 (R) 0.029572 0.031274 0.032931 0.026295 0.042725
k=19 (S) 0.027484 0.044015 0.031383 0.037050 0.046068
k=20 (T) 0.051330 0.047532 0.043297 0.040039 0.036849
k=21 (U) 0.034691 0.045869 0.024400 0.022020 0.029838
k=22 (V) 0.054835 0.028627 0.031971 0.039436 0.062661
k=23 (W) 0.033373 0.035513 0.047827 0.030642 0.026361
k=24 (X) 0.048700 0.022777 0.034515 0.022410 0.026991
k=25 (Y) 0.033561 0.023278 0.045237 0.034797 0.027990
k=26 (Z) 0.050657 0.023858 0.040665 0.025854 0.028682

The example above describes the output \(y\) of the network \(\mathcal{N}_w\). The label \(label\) for this audio clip should be the three letters ‘C’, ‘A’, ‘T’, which in alphabet indices is

\[label=[3,1,20] \]
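The index convention above (0 for blank, 1-26 for A-Z) can be sketched in a few lines of Python. The names `ALPHABET`, `encode`, and `decode` are hypothetical helpers introduced only for illustration; they are not part of any library.

```python
import string

# Index 0 is reserved for blank; letters A-Z take indices 1-26.
BLANK = "-"
ALPHABET = BLANK + string.ascii_uppercase  # 27 symbols in total


def encode(text):
    """Map a string of uppercase letters to its alphabet indices."""
    return [ALPHABET.index(ch) for ch in text]


def decode(indices):
    """Map alphabet indices back to a string."""
    return "".join(ALPHABET[i] for i in indices)


label = encode("CAT")
print(label)          # [3, 1, 20]
print(decode(label))  # CAT
```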

Derivation of the CTC computation

The paper gives the following formula for CTCLoss:

\[O^{ML}(S,\mathcal{N}_w)=-\sum_{x,z \in S}ln(p(z|x)) \]

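So what does this formula mean?

Total probability \(p(z|x)\)

The key step in CTCLoss is computing the conditional probability \(p(z|x)\) for each sample \({\{x,z\}} \in S\). Here \(z\) is the target label paired one-to-one with \(x\), while \(l\) denotes any label sequence that is valid over the alphabet; \(z\) is just one particular such \(l\). During training we set \(l=z\), but the derivation is stated for a general \(l\). So \(p(z|x)\) can be written as \(p(l|x)\), which is computed as follows: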

\[p(l|x)=\sum_{\pi \in \mathcal{B}^{-1}(l)}{p(\pi|x)} \]

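What a path means

The network \(\mathcal{N}_w\) outputs \(y\in\mathbb{R}^{K \times T}\). It has \(T\) time steps with \(K\) possible outputs at each step, giving \(K^T\) paths in total. In the example above \(K=27,T=5\), so there are \(27^5=14348907\) possible paths. Even with only \(T=5\), the number of paths is already very large.

Path probability \(p(\pi|x)\)

Table 1 gives the probability of every letter at every time step. Choosing one letter per time step forms a path, and the probability of that path is the product of the probabilities of the letters chosen at each step.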

\[\begin{aligned} p(\pi|x)&=\prod_{t=1}^{T}{y_{k=\pi^t}^t} \\ &=y_{k=\pi^1}^1\times y_{k=\pi^2}^2\times y_{k=\pi^3}^3\times...\times y_{k=\pi^T}^T \end{aligned}\]
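As a rough sketch of the product above, the function below multiplies one entry per time step. The matrix `y_toy` is a hypothetical 3-symbol, 3-step distribution made up for illustration (it is not taken from Table 1), and `path_probability` is an illustrative helper name.

```python
import math


def path_probability(y, path):
    """Return the product over t of y[path[t]][t], with y indexed as y[k][t]."""
    return math.prod(y[k][t] for t, k in enumerate(path))


# Hypothetical toy distribution: K=3 symbols (rows), T=3 time steps (columns).
# Each column sums to 1.
y_toy = [
    [0.2, 0.5, 0.1],
    [0.5, 0.3, 0.3],
    [0.3, 0.2, 0.6],
]

# Path (1, 0, 2) picks y[1][0], y[0][1], y[2][2] = 0.5 * 0.5 * 0.6
p = path_probability(y_toy, [1, 0, 2])

# Number of distinct paths in the article's example (K=27, T=5)
num_paths = 27 ** 5
```

The row/column layout `y[k][t]` mirrors Table 1, where rows are symbols and columns are time steps.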

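What is the \(\mathcal{B}\) mapping

Applied to any of the \(27^5\) paths above, the \(\mathcal{B}\) mapping first merges adjacent repeated elements and then removes all blanks \((-)\). For example:

\[\mathcal{B}(a{-}ab{-}) = \mathcal{B}({-}aa{-}{-}abb) = aab \]

\[\mathcal{B}(C{-}AT{-}) = \mathcal{B}(CC{-}AT) = CAT \]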

Likewise, \(\mathcal{B}^{-1}(l)\) is the inverse of \(\mathcal{B}(\pi)\): it denotes the set of all paths satisfying \(\mathcal{B}(\pi)=l\).
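A minimal sketch of the \(\mathcal{B}\) mapping, assuming paths are written as strings with `-` for blank. The function name `collapse` is a hypothetical choice; `itertools.groupby` does the merging of adjacent repeats.

```python
from itertools import groupby

BLANK = "-"


def collapse(path):
    """Merge adjacent repeated symbols, then drop every blank."""
    merged = [symbol for symbol, _ in groupby(path)]
    return "".join(s for s in merged if s != BLANK)


print(collapse("a-ab-"))     # aab
print(collapse("-aa--abb"))  # aab
print(collapse("CC-AT"))     # CAT
```

Merging before removing blanks is what lets a blank separate two genuine repeats: `a-a` collapses to `aa`, while `aa` collapses to `a`.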

Note that \(p(l|x)\) is not the sum of the probabilities of all paths, but only of those paths satisfying \(\mathcal{B}(\pi)=l\).
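To make the sum over \(\mathcal{B}^{-1}(l)\) concrete, here is a brute-force sketch that enumerates every path and keeps those that collapse to the label. It assumes blank has index 0 and uses a hypothetical 2-symbol, 2-step uniform distribution; `label_probability` is an illustrative name. Enumerating \(K^T\) paths is only feasible for tiny inputs, and practical implementations use dynamic programming instead.

```python
import math
from itertools import groupby, product


def collapse(path, blank=0):
    """Merge adjacent repeated indices, then drop blanks."""
    merged = [k for k, _ in groupby(path)]
    return tuple(k for k in merged if k != blank)


def label_probability(y, label, blank=0):
    """Brute-force p(l|x): sum path probabilities over paths that collapse to label."""
    num_classes, num_steps = len(y), len(y[0])
    total = 0.0
    for path in product(range(num_classes), repeat=num_steps):
        if collapse(path, blank) == tuple(label):
            total += math.prod(y[k][t] for t, k in enumerate(path))
    return total


# Hypothetical toy case: K=2 (blank, 'a'), T=2, uniform probabilities.
y_toy = [[0.5, 0.5], [0.5, 0.5]]
p_a = label_probability(y_toy, [1])     # paths (0,1), (1,0), (1,1)
p_empty = label_probability(y_toy, [])  # path (0,0) only
```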

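Computing CTCLoss by hand, step by step

Using the example above, we now compute CTCLoss by hand, one step at a time.

Find all paths satisfying \(\mathcal{B}(\pi)=l\) with \(l\)=“CAT”

Of the \(27^5\) paths above, 28 satisfy \(\mathcal{B}(\pi)=l\) with \(l\)=“CAT”,
as listed in Table 2.

Table 2. All paths satisfying the condition (28 in total)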
t=1 t=2 t=3 t=4 t=5
\(\pi_{1}\) - - C A T
\(\pi_{2}\) - C - A T
\(\pi_{3}\) - C C A T
\(\pi_{4}\) - C A - T
\(\pi_{5}\) - C A A T
\(\pi_{6}\) - C A T -
\(\pi_{7}\) - C A T T
\(\pi_{8}\) C - - A T
\(\pi_{9}\) C - A - T
\(\pi_{10}\) C - A A T
\(\pi_{11}\) C - A T -
\(\pi_{12}\) C - A T T
\(\pi_{13}\) C C - A T
\(\pi_{14}\) C C C A T
\(\pi_{15}\) C C A - T
\(\pi_{16}\) C C A A T
\(\pi_{17}\) C C A T -
\(\pi_{18}\) C C A T T
\(\pi_{19}\) C A - - T
\(\pi_{20}\) C A - T -
\(\pi_{21}\) C A - T T
\(\pi_{22}\) C A A - T
\(\pi_{23}\) C A A A T
\(\pi_{24}\) C A A T -
\(\pi_{25}\) C A A T T
\(\pi_{26}\) C A T - -
\(\pi_{27}\) C A T T -
\(\pi_{28}\) C A T T T
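The table can be reproduced by enumeration. The sketch below relies on the observation that a path collapsing to “CAT” can only contain blank, C, A, and T (any other letter would survive the collapse), so it searches \(4^5=1024\) candidates rather than all \(27^5\). The helper `collapse` is the same hypothetical function as in the \(\mathcal{B}\) mapping.

```python
from itertools import groupby, product


def collapse(path, blank="-"):
    """Merge adjacent repeated symbols, then drop every blank."""
    merged = [s for s, _ in groupby(path)]
    return "".join(s for s in merged if s != blank)


# Symbol order "-CAT" makes the enumeration order match Table 2.
symbols = "-CAT"
paths = ["".join(p) for p in product(symbols, repeat=5) if collapse(p) == "CAT"]

print(len(paths))
for i, path in enumerate(paths, start=1):
    print(f"pi_{i}: {' '.join(path)}")
```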

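Compute the probability \(p(\pi|x)\) of each path

Path \(\pi_1\) corresponds to the sequence "- - C A T". Converting this sequence to alphabet indices gives (0, 0, 3, 1, 20),
so the values along path \(\pi_1\) at each time step are

\[y_{k=\pi^1_1}^1=y_{0}^1=0.031953 \]

\[y_{k=\pi^2_1}^2=y_{0}^2=0.044296 \]

\[y_{k=\pi^3_1}^3=y_{3}^3=0.031872 \]

\[y_{k=\pi^4_1}^4=y_{1}^4=0.027295 \]

\[y_{k=\pi^5_1}^5=y_{20}^5=0.036849 \]

The probability \(p(\pi_1|x)\) of path \(\pi_1\) is therefore computed as follows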

\[\begin{aligned} p(\pi_1|x)&=\prod_{t=1}^{T}{y_{k=\pi_1^t}^t} \\ &=y_{k=\pi_1^1}^1\times y_{k=\pi_1^2}^2 \times y_{k=\pi_1^3}^3 \times \cdots \times y_{k=\pi_1^T}^T \\ &=y_{0}^1 \times y_{0}^2 \times y_{3}^3 \times y_{1}^4\times y_{20}^5 \\ &=0.031953 \times 0.044296\times 0.031872\times 0.027295\times 0.036849 \\ &\approx 4.5373\times10^{-8} \end{aligned}\]
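The arithmetic can be checked quickly with the five rounded Table 1 entries along \(\pi_1\). This is only a sanity check on the hand calculation, using the six-decimal values shown in the table.

```python
import math

# Table 1 entries along pi_1 = (-, -, C, A, T), i.e. indices (0, 0, 3, 1, 20)
factors = [0.031953, 0.044296, 0.031872, 0.027295, 0.036849]

p_pi1 = math.prod(factors)
print(f"{p_pi1:.4e}")
```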

The remaining paths are computed the same way. The values below use the unrounded \(y\), which is why \(p(\pi_1|x)\) differs from the hand calculation in the last digit:

\[p(\pi_1|x)=4.5374\times10^{-8},\ p(\pi_2|x)=5.6482\times10^{-8},\ p(\pi_3|x)=4.7006\times10^{-8},\ p(\pi_4|x)=6.6003\times10^{-8}\]

\[p(\pi_5|x)=4.7014\times10^{-8},\ p(\pi_6|x)=5.1401\times10^{-8},\ p(\pi_7|x)=6.8965\times10^{-8},\ p(\pi_8|x)=5.0050\times10^{-8}\]

\[p(\pi_9|x)=5.8487\times10^{-8},\ p(\pi_{10}|x)=4.1660\times10^{-8},\ p(\pi_{11}|x)=4.5547\times10^{-8},\ p(\pi_{12}|x)=6.1111\times10^{-8}\]

\[p(\pi_{13}|x)=5.1850\times10^{-8},\ p(\pi_{14}|x)=4.3151\times10^{-8},\ p(\pi_{15}|x)=6.0590\times10^{-8},\ p(\pi_{16}|x)=4.3158\times10^{-8}\]

\[p(\pi_{17}|x)=4.7185\times10^{-8},\ p(\pi_{18}|x)=6.3309\times10^{-8},\ p(\pi_{19}|x)=4.8163\times10^{-8},\ p(\pi_{20}|x)=3.7508\times10^{-8}\]

\[p(\pi_{21}|x)=5.0324\times10^{-8},\ p(\pi_{22}|x)=4.0090\times10^{-8},\ p(\pi_{23}|x)=2.8556\times10^{-8},\ p(\pi_{24}|x)=3.1220\times10^{-8}\]

\[p(\pi_{25}|x)=4.1889\times10^{-8},\ p(\pi_{26}|x)=4.0583\times10^{-8},\ p(\pi_{27}|x)=4.2404\times10^{-8},\ p(\pi_{28}|x)=5.6894\times10^{-8}\]

Compute the total probability \(p(l|x)\)

\(p(l|x)\) is the sum of the probabilities of all paths satisfying \(\mathcal{B}(\pi)=l\).

\[\begin{aligned} p(l|x)&=\sum_{\pi \in \mathcal{B}^{-1}(l)}{p(\pi|x)} \\ &=p(\pi_1|x)+p(\pi_2|x)+p(\pi_3|x)+\cdots+p(\pi_{28}|x) \\ &=4.5374\times10^{-8} + 5.6482\times10^{-8}+4.7006\times10^{-8}+\cdots+5.6894\times10^{-8} \\ &\approx 1.366\times10^{-6} \end{aligned}\]

Compute the loss function CTCLoss

Since the example contains only one sample, the CTCLoss below is simply the loss of that single sample.

\[\begin{aligned} O^{ML}(S,\mathcal{N}_w)&=-\sum_{x,z \in S}ln(p(z|x)) \\ &=-ln(p(z|x)) \\ &=-ln(1.366\times10^{-6}) \\ &\approx 13.5036 \end{aligned}\]
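The last two steps, summing the 28 path probabilities and taking the negative natural log, can be checked with a few lines of Python. The list below simply copies the 28 values listed earlier, in units of \(10^{-8}\).

```python
import math

# The 28 path probabilities pi_1 .. pi_28 from the text, in units of 1e-8
path_probs = [
    4.5374, 5.6482, 4.7006, 6.6003, 4.7014, 5.1401, 6.8965,
    5.0050, 5.8487, 4.1660, 4.5547, 6.1111, 5.1850, 4.3151,
    6.0590, 4.3158, 4.7185, 6.3309, 4.8163, 3.7508, 5.0324,
    4.0090, 2.8556, 3.1220, 4.1889, 4.0583, 4.2404, 5.6894,
]

p_label = sum(path_probs) * 1e-8  # total probability p(l|x)
loss = -math.log(p_label)         # CTCLoss for the single sample

print(f"p(l|x) = {p_label:.4e}")
print(f"loss   = {loss:.4f}")
```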

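Verifying against library CTCLoss functions

Softmax processing of the output \(y\_out\) of the network \(\mathcal{N}_w\)

One point needs explaining here: the relationship between the CTCLoss input \(ctc\_input\) and the output \(y\_out\) of the network \(\mathcal{N}_w\).

The final layer of the network \(\mathcal{N}_w\) has no softmax, so the values of \(y\_out\) at each time step do not sum to 1. To normalize them into a probability distribution, softmax is applied to \(y\_out\).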

\[y\_softmax=softmax(y\_out) \]

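CTCLoss also involves a large number of probability multiplications, so \(ln\) is applied to \(y\_softmax\).
This turns the multiplications into additions, which speeds up the computation and helps avoid floating-point underflow.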

\[ctc\_input=ln(y\_softmax) \]
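The two steps above are usually fused into a single log-softmax. Below is a minimal pure-Python sketch for one column of scores, assuming a short hypothetical column of three values rather than the full 27; `log_softmax` here is an illustrative helper, not a library call. It also checks that a product of probabilities equals the exponential of the sum of their logs.

```python
import math


def log_softmax(scores):
    """Numerically stable log-softmax over one column of raw scores."""
    m = max(scores)  # subtract the max so exp() cannot overflow
    log_sum = m + math.log(sum(math.exp(s - m) for s in scores))
    return [s - log_sum for s in scores]


scores = [0.347713, 0.149997, 0.586092]  # hypothetical, truncated column
log_probs = log_softmax(scores)
probs = [math.exp(v) for v in log_probs]

# Multiplying probabilities is the same as adding their logs.
product_direct = math.prod(probs)
product_via_logs = math.exp(sum(log_probs))
```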

To keep the walkthrough readable, the example above has already assumed

\[y=y\_softmax \]

The table below is \(y\_out\); the columns clearly do not sum to 1.

\(y\_out_t^k\) t=1 t=2 t=3 t=4 t=5
k=0 (-) 0.347713 0.755077 0.678652 0.585987 0.123084
k=1 (A) 0.149997 0.377396 0.495177 0.246735 0.205494
k=2 (B) 0.586092 0.216019 0.189710 0.666416 0.146515
k=3 (C) 0.262145 0.790407 0.495006 0.083483 0.189072
k=4 (D) 0.044454 0.949304 0.147608 0.625960 0.042652
k=5 (E) 0.754933 0.327565 0.054974 0.660945 0.635198
k=6 (F) 0.242785 0.671264 0.850713 0.729752 0.281867
k=7 (G) 0.442402 0.438645 0.560560 0.890752 0.538597
k=8 (H) 0.687796 0.833501 0.929609 0.982303 0.695163
k=9 (I) 0.359228 0.768854 0.696667 0.769029 0.499116
k=10 (J) 0.736340 0.167254 0.582791 0.581446 0.535801
k=11 (K) 0.394707 0.861980 0.815397 0.928313 0.445183
k=12 (L) 0.683416 0.989872 0.879014 0.580090 0.123932
k=13 (M) 0.704047 0.514423 0.988912 0.016983 0.490357
k=14 (N) 0.442305 0.884281 0.000522 0.120860 0.852998
k=15 (O) 0.019578 0.588026 0.865439 0.862711 0.873927
k=16 (P) 0.330858 0.154752 0.612566 0.484297 0.270294
k=17 (Q) 0.424309 0.199863 0.989950 0.844856 0.208461
k=18 (R) 0.270270 0.406955 0.527680 0.209405 0.564980
k=19 (S) 0.197054 0.748706 0.479523 0.552291 0.640312
k=20 (T) 0.821721 0.825584 0.801348 0.629883 0.417029
k=21 (U) 0.429921 0.789963 0.227843 0.031991 0.205976
k=22 (V) 0.887771 0.318524 0.498094 0.614713 0.947933
k=23 (W) 0.391183 0.534064 0.900852 0.362411 0.082071
k=24 (X) 0.769114 0.089951 0.574661 0.049533 0.105709
k=25 (Y) 0.396792 0.111706 0.845178 0.489570 0.142041
k=26 (Z) 0.808514 0.136293 0.738640 0.192510 0.166460

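Verification with the PyTorch library function

For details on using CTCLoss, see the official PyTorch documentation. Note that the input is laid out as (T, N, C), which is (5, 1, 27) here, and must already be log-probabilities.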

import torch
import torch.nn as nn
import numpy as np

y_softmax = np.array([
    [[0.0319533345695271, 0.0262210133693412, 0.0405548727460100, 0.0293328834922530, 0.0235946021815836, 0.0480142162870594, 0.0287704618407728, 0.0351268637054168, 0.0448965052477630, 0.0323234212279283, 0.0471297269219778, 0.0334908192070999, 0.0447002788315031, 0.0456320948241136,
        0.0351234600906292, 0.0230148922614546, 0.0314192811142228, 0.0344970346892286, 0.0295721871384341, 0.0274843752526059, 0.0513304969210734, 0.0346911732659917, 0.0548353372646645, 0.0333729892573427, 0.0486999624899632, 0.0335606882517763, 0.0506570275502634]],
    [[0.0442961938109001, 0.0303627704208565, 0.0258378526020265, 0.0458891577161975, 0.0537920435977104, 0.0288868677848477, 0.0407349328912650, 0.0322806067098565, 0.0479099042067772, 0.0449106925711146, 0.0246080887866719, 0.0492939884049119, 0.0560191619281624, 0.0348218517081914,
        0.0504056201105211, 0.0374815087428365, 0.0243023731122621, 0.0254237678526359, 0.0312736688595233, 0.0440148630768450, 0.0475321094768427, 0.0458687788283468, 0.0286268732637606, 0.0355125367928648, 0.0227774801386588, 0.0232784351056503, 0.0238578714997625]],
    [[0.0382974368377362, 0.0318777312135849, 0.0234868589674224, 0.0318722744011979, 0.0225185381516373, 0.0205262552943881, 0.0454877627911883, 0.0340316234294017, 0.0492219436202117, 0.0389936131137926, 0.0347966678592871, 0.0439093761642613, 0.0467935124498177, 0.0522292290638150,
        0.0194384495697102, 0.0461625681675025, 0.0358483354617907, 0.0522835019782284, 0.0329308772273817, 0.0313826141807340, 0.0432967801742709, 0.0243997674509821, 0.0319708630090250, 0.0478266566415420, 0.0345149265806327, 0.0452367066323343, 0.0406651295681235]],
    [[0.0383195501689954, 0.0272951137973125, 0.0415288927451887, 0.0231838517718695, 0.0398823138441169, 0.0413022813256117, 0.0442442329310963, 0.0519730489462436, 0.0569558497142297, 0.0460166028890008, 0.0381459528257684, 0.0539623316283564, 0.0380942573161036, 0.0216922730554261,
        0.0240667868142706, 0.0505358960731075, 0.0346143968556499, 0.0496415831760055, 0.0262949856792733, 0.0370498580320465, 0.0400391034751884, 0.0220202876462848, 0.0394362954874324, 0.0306423990773223, 0.0224099657044701, 0.0347974172676594, 0.0258544717519696]],
    [[0.0274643882294982, 0.0298236175649923, 0.0281155092606543, 0.0293378537462717, 0.0253418924737544, 0.0458330002578632, 0.0321905618820226, 0.0416126048467898, 0.0486654573566434, 0.0400017201897758, 0.0414964341812715, 0.0379014590893513, 0.0274877024782956, 0.0396528862221281,
        0.0569859416555112, 0.0581911831104043, 0.0318201830875284, 0.0299122412570334, 0.0427250763149338, 0.0460679863549903, 0.0368492548068844, 0.0298379764585031, 0.0626610201269008, 0.0263607892806820, 0.0269913345266294, 0.0279900073565483, 0.0286819178841385]]
]).astype("float32")


labels = np.array([[3, 1, 20]]).astype("int32")
input_lengths = np.array([5]).astype("int64")
label_lengths = np.array([3]).astype("int64")

ctc_input = torch.tensor(y_softmax).log()
labels = torch.tensor(labels)
input_lengths = torch.tensor(input_lengths)
label_lengths = torch.tensor(label_lengths)

ctc_loss = nn.CTCLoss(reduction='none')
loss = ctc_loss(ctc_input, labels, input_lengths, label_lengths)
print('loss is {}'.format(loss))

loss is tensor([13.5036])

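Using the paddle library function

For details on using CTCLoss, see the official paddle documentation.

Because paddle's CTCLoss already applies log_softmax internally, its input can be \(y\_out\) directly.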

import numpy as np
import paddle
import paddle.nn.functional as F


y_out = np.array([
    [[0.347712671277525, 0.149997253831683, 0.586092067231462, 0.262145317727807, 0.0444540922782385, 0.754933267231179, 0.242785357820962, 0.442402313001943, 0.687796085120107, 0.359228210401861, 0.736340074301202, 0.394707475278763, 0.683415866967978, 0.704047430334266,
        0.442305413383371, 0.0195776235533187, 0.330857880214071, 0.424309496833137, 0.270270423432065, 0.197053798095456, 0.821721184961310, 0.429921409383266, 0.887770954256354, 0.391182995461163, 0.769114387388296, 0.396791517013617, 0.808514095887345]],
    [[0.755077099007084, 0.377395544835103, 0.216018915961394, 0.790407217966913, 0.949303911849797, 0.327565434075205, 0.671264370451740, 0.438644982586956, 0.833500595588975, 0.768854252429615, 0.167253545494722, 0.861980478702072, 0.989872153631504, 0.514423456505704,
        0.884281023126955, 0.588026055308498, 0.154752348656045, 0.199862822857452, 0.406954837138907, 0.748705718215691, 0.825583815786156, 0.789963029944531, 0.318524245398992, 0.534064127370726, 0.0899506787705811, 0.111705744193203, 0.136292548938299]],
    [[0.678652304800188, 0.495177019089661, 0.189710406017580, 0.495005824990221, 0.147608221976689, 0.0549741469061882, 0.850712674289007, 0.560559527354885, 0.929608866756663, 0.696667200555228, 0.582790965175840, 0.815397211477421, 0.879013904597178, 0.988911616079589,
        0.000522375356944771, 0.865438591013025, 0.612566469483999, 0.989950205708831, 0.527680069338442, 0.479523385210219, 0.801347605521952, 0.227842935706042, 0.498094291196390, 0.900852488532005, 0.574661219130188, 0.845178185054037, 0.738640291995402]],
    [[0.585987035826476, 0.246734525985975, 0.666416217319468, 0.0834828136026227, 0.625959785171583, 0.660944557947342, 0.729751855317221, 0.890752116325322, 0.982303222883606, 0.769029085335896, 0.581446487875398, 0.928313062314188, 0.580090365758442, 0.0169829383372613,
        0.120859571098558, 0.862710718699670, 0.484296511212103, 0.844855674576263, 0.209405084020935, 0.552291341538775, 0.629883385064421, 0.0319910157625669, 0.614713419117141, 0.362411462273053, 0.0495325790420612, 0.489569989177322, 0.192510396062075]],
    [[0.123083747545945, 0.205494170907680, 0.146514910614890, 0.189072174472614, 0.0426524109111434, 0.635197916859882, 0.281866855880430, 0.538596678045340, 0.695163039444332, 0.499116013482590, 0.535801055751113, 0.445183165296042, 0.123932277598070, 0.490357293468018,
        0.852998155340816, 0.873927405861733, 0.270294332292698, 0.208461358751314, 0.564979570738201, 0.640311825162758, 0.417028951642886, 0.205975515532243, 0.947933121293169, 0.0820712070977259, 0.105709426581721, 0.142041121903998, 0.166460440876421]]
]).astype("float32")

labels = np.array([[3, 1, 20]]).astype("int32")
input_lengths = np.array([5]).astype("int64")
label_lengths = np.array([3]).astype("int64")


y_out=paddle.to_tensor(y_out)
labels = paddle.to_tensor(labels)
input_lengths = paddle.to_tensor(input_lengths)
label_lengths = paddle.to_tensor(label_lengths)


loss = paddle.nn.CTCLoss(blank=0, reduction='none')(y_out, labels,
                                                    input_lengths,
                                                    label_lengths)
print('loss is {}'.format(loss))

loss is Tensor(shape=[1], dtype=float32, place=Place(cpu), stop_gradient=True,
       [13.50364304])

Source: https://www.cnblogs.com/chenkui164/p/16359288.html