WeNet Model Pipeline Walkthrough
Author: Internet
asr_model
-
encoder
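input: speech (16, 183, 80), speech_length, text (16, 6), text_length
# speech is (batch, frames, features): 183 is set by the longest utterance in the batch
# text is (batch, tokens): 6 is set by the longest label sequence in the batch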
-
make_pad_mask
mask: (16, 183)
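A minimal sketch of how a padding mask of this shape can be built from the per-utterance lengths. It follows the usual convention that True marks padded frames, but it is re-implemented here for illustration rather than copied from WeNet, and the example lengths are made up.

```python
import torch

def make_pad_mask(lengths: torch.Tensor, max_len: int = 0) -> torch.Tensor:
    """Return a (batch, max_len) bool mask that is True at padded positions."""
    max_len = max_len if max_len > 0 else int(lengths.max().item())
    # Frame indices 0..max_len-1, broadcast against each utterance length.
    frame_idx = torch.arange(max_len, device=lengths.device).unsqueeze(0)
    return frame_idx >= lengths.unsqueeze(1)

# Hypothetical lengths for a batch of 3 utterances; the longest sets the width.
speech_lengths = torch.tensor([5, 3, 2])
pad_mask = make_pad_mask(speech_lengths)
print(pad_mask.shape)  # torch.Size([3, 5])
```

With a batch of 16 utterances whose longest item has 183 frames, the same call would give the (16, 183) mask listed above.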
-
subsampling
input: (speech, mask)
-
conv(speech)
torch.nn.Sequential(
    torch.nn.Conv2d(1, odim, 3, 2),
    torch.nn.ReLU(),
    torch.nn.Conv2d(odim, odim, 3, 2),
    torch.nn.ReLU(),
)
# output: (16, 256, 45, 19)
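The 45 and 19 in the output shape follow from the standard output-size formula for an unpadded convolution, applied twice. The sketch below checks the arithmetic, assuming the input is laid out as (batch, time, feature) with 183 frames and 80 feature bins; `conv_out_len` is a helper name introduced here for illustration.

```python
def conv_out_len(n: int, kernel: int = 3, stride: int = 2) -> int:
    """Output length of an unpadded convolution along one axis."""
    return (n - kernel) // stride + 1

frames, feat_dim = 183, 80  # time and feature axes of the padded batch
t = conv_out_len(conv_out_len(frames))    # two stride-2 convs along time
f = conv_out_len(conv_out_len(feat_dim))  # two stride-2 convs along features
print(t, f)  # 45 19
```

The time axis shrinks by roughly a factor of 4 (183 to 45), which is why this layer is described as subsampling.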
-
self.out: linear
torch.nn.Linear(odim * (((idim - 1) // 2 - 1) // 2), odim)
# in_features = 256 * 19 = 4864 for odim = 256, idim = 80; output: (16, 45, 256)
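Before the linear layer can be applied, the channel and feature axes of the conv output have to be merged into one. The sketch below shows one way to do that reshape; the batch size is reduced to 2 to keep it light, the random tensor is a stand-in for the real conv output, and the variable names are illustrative.

```python
import torch

batch, odim, t, f = 2, 256, 45, 19  # small batch to keep the demo light
x = torch.randn(batch, odim, t, f)  # stand-in for the conv output

# Move time next to batch, then merge channels and features: (B, T, C * F).
x = x.transpose(1, 2).contiguous().view(batch, t, odim * f)

idim = 80
proj = torch.nn.Linear(odim * (((idim - 1) // 2 - 1) // 2), odim)
y = proj(x)
print(y.shape)  # torch.Size([2, 45, 256])
```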
-
self.pos_enc
pos_emb: (1, 45, 256)
output (dropout is a torch.nn.Dropout module):
speech = dropout(speech)  # (16, 45, 256)
pos_emb = dropout(pos_emb)  # (1, 45, 256)
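To make the shapes concrete, here is a sketch that builds a (1, 45, 256) position table and applies dropout to both tensors. It assumes the standard sinusoidal formulation and a dropout rate of 0.1, neither of which is stated in the outline. Any scaling of the features, or adding pos_emb to them, is left out because the outline records only the dropout steps. `sinusoidal_table` is a hypothetical helper name.

```python
import math
import torch

def sinusoidal_table(length: int, d_model: int) -> torch.Tensor:
    """Return a (1, length, d_model) sinusoidal position table."""
    position = torch.arange(length, dtype=torch.float32).unsqueeze(1)
    div_term = torch.exp(
        torch.arange(0, d_model, 2, dtype=torch.float32)
        * -(math.log(10000.0) / d_model)
    )
    pe = torch.zeros(length, d_model)
    pe[:, 0::2] = torch.sin(position * div_term)  # even channels
    pe[:, 1::2] = torch.cos(position * div_term)  # odd channels
    return pe.unsqueeze(0)

dropout = torch.nn.Dropout(p=0.1)   # rate is an assumed value
speech = torch.randn(16, 45, 256)   # stand-in for the projected features
pos_emb = sinusoidal_table(45, 256)

speech_out = dropout(speech)        # (16, 45, 256)
pos_emb_out = dropout(pos_emb)      # (1, 45, 256)
print(speech_out.shape, pos_emb_out.shape)
```

The leading dimension of 1 lets the same table broadcast over all 16 utterances in the batch.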
-
-
$subsampling
output: speech, pos_emb, mask (16, 1, 45)
# mask = x_mask[:, :, :-2:2][:, :, :-2:2]
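The double slice shortens the mask along time in step with the two kernel-3, stride-2 convolutions. The sketch below reproduces the 183 to 45 reduction. It assumes the mask has already been inverted and given a middle axis, so that it is (batch, 1, frames) with True marking valid frames; the two utterance lengths are made up.

```python
import torch

lengths = torch.tensor([183, 120])  # hypothetical frame counts
frame_idx = torch.arange(183).unsqueeze(0)
# (2, 1, 183), True = valid frame
x_mask = (frame_idx < lengths.unsqueeze(1)).unsqueeze(1)

# Each slice mirrors one kernel-3, stride-2 convolution along time.
sub_mask = x_mask[:, :, :-2:2][:, :, :-2:2]
print(sub_mask.shape)        # torch.Size([2, 1, 45])
print(sub_mask.sum(dim=-1))  # valid frames left per utterance
```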
-
add_optional_chunk_mask
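The outline only names this step. For orientation, the sketch below builds a static chunk attention mask and combines it with a padding mask, which is the kind of masking used for chunk-based (streaming) attention. It is an illustration under assumptions, not WeNet's implementation: `chunk_attention_mask` is a hypothetical name, the chunk size is fixed at 2, and limits on left context are ignored.

```python
import torch

def chunk_attention_mask(size: int, chunk_size: int) -> torch.Tensor:
    """(size, size) bool mask: frame i may attend up to the end of its own chunk."""
    idx = torch.arange(size)
    chunk_end = (idx // chunk_size + 1) * chunk_size  # exclusive end per query frame
    return idx.unsqueeze(0) < chunk_end.unsqueeze(1)

pad_mask = torch.ones(1, 1, 6, dtype=torch.bool)  # (B, 1, T), True = valid frame
pad_mask[..., 5] = False                          # pretend the last frame is padding
chunk_mask = chunk_attention_mask(6, chunk_size=2).unsqueeze(0)  # (1, T, T)
combined = pad_mask & chunk_mask                  # (B, T, T)
print(combined[0].int())
```

Each row of the combined mask is one query frame: it can see every earlier chunk and its own chunk, but never a padded frame.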
-
Source: https://www.cnblogs.com/lhx9527/p/16138419.html