其他分享
首页 > 其他分享> > transformer代码笔记----pre_process.py

transformer代码笔记----pre_process.py

作者:互联网

import os
import pickle
from tqdm import tqdm
from config import wav_folder, transcript_file, pickle_file
from utils import ensure_folder


def get_data(split):
    print('getting {} data...'.format(split)) #对获取的数据名打印
 
    global VOCAB   #定义全局变量
    with open(transcript_file, 'r', encoding='utf-8') as file: #打开文件transcript_file,仅对其读操作,重命名为file
        lines = file.readlines() #逐行读取文件内容

    tran_dict = dict() #创建空字典
    for line in lines:  #迭代file文件中的每一行
        tokens = line.split() #将一行的输入进行切分。str.split(str="", num=string.count(str)):str为分隔符,默认空格;num为切分次数,默认全切分
        key = tokens[0]
        trn = ''.join(tokens[1:]) #'_'.join(sequence):将sequence中的元素以'_'连接形成一个新的元素
        tran_dict[key] = trn   # tran_dict: {'BAC0009123': wav1.wav, ...}

    samples = []

    folder = os.path.join(wav_folder, split)    # data/data_aishell/wav/train  os.path.join():连接路径名,以/连接
    ensure_folder(folder)    # 确保floder是一个目录,如果不存在该路径下的目录就生成一个新的目录
    #os.listdir():以列表的形式提取路径下的文件。os.path.isdir():判断是否存在该文件。最终dirs中以列表形式存储folder路径下的所有文件
    dirs = [os.path.join(folder, d) for d in os.listdir(folder) if os.path.isdir(os.path.join(folder, d))]  # data/data_aishell/wav/train/S0003
    for dir in tqdm(dirs): 
    #Tqdm 是一个快速,可扩展的Python进度条,可以在 Python 长循环中添加一个进度提示信息,用户只需要封装任意的迭代器 tqdm(iterator)。
        files = [f for f in os.listdir(dir) if f.endswith('.wav')]    # [wav1, wav2, .....]
        #endswith() 方法用于判断字符串是否以指定后缀结尾,如果以指定后缀结尾返回 True,否则返回 False。

        for f in files:
            wave = os.path.join(dir, f)  # data/data_aishell/wav/train/S0003/wav1.wav

            key = f.split('.')[0] #切分f,取第一个元素:wav1
            if key in tran_dict:
                trn = tran_dict[key]
                trn = list(trn.strip()) + ['<eos>'] #获取数据,并在每行数据后加结束标志
                #strip() 方法用于移除字符串头尾指定的字符(默认为空格或换行符)或字符序列。

                for token in trn:
                    build_vocab(token)

                trn = [VOCAB[token] for token in trn]

                samples.append({'trn': trn, 'wave': wave}) #append() 方法用于在列表末尾添加新的对象。

    print('split: {}, num_files: {}'.format(split, len(samples)))
    return samples


def build_vocab(token):
    global VOCAB, IVOCAB
    if not token in VOCAB: #将token及index添加到IVOCAB和VOCAB中
        next_index = len(VOCAB)
        VOCAB[token] = next_index
        IVOCAB[next_index] = token


if __name__ == "__main__":
    VOCAB = {'<sos>': 0, '<eos>': 1}
    IVOCAB = {0: '<sos>', 1: '<eos>'}

    data = dict()
    data['VOCAB'] = VOCAB
    data['IVOCAB'] = IVOCAB
    data['train'] = get_data('train')
    data['dev'] = get_data('dev')
    data['test'] = get_data('test')

    with open(pickle_file, 'wb') as file:
        pickle.dump(data, file)

    print('num_train: ' + str(len(data['train'])))
    print('num_dev: ' + str(len(data['dev'])))
    print('num_test: ' + str(len(data['test'])))
    print('vocab_size: ' + str(len(data['VOCAB'])))

 

标签:pre,VOCAB,transformer,trn,process,os,file,folder,data
来源: https://www.cnblogs.com/Uriel-w/p/15426160.html