PyTorch DataLoader 迭代类型数据的链式处理分析
作者:互联网
wenet 官方代码 https://github.com/wenet-e2e/wenet 在最新的 UIO 模式中加入了链式数据处理。
import time
import random
class Process:
    """Lazily apply one pipeline stage ``f`` to an iterable ``data``.

    Instances are iterable themselves, so stages can be chained by nesting:
    ``Process(Process(data, f1), f2)``. Nothing runs until the outermost
    object is iterated. Iterating again restarts the chain from ``data``,
    provided ``data`` itself can be iterated more than once (a list or
    another ``Process`` can; a one-shot iterator cannot).

    Args:
        data: the upstream iterable (raw data or another ``Process``).
        f: the stage function. It is called as ``f(iterator, *args, **kw)``
            and must return an iterable (typically a generator).
        *args, **kw: extra arguments forwarded to ``f`` on every iteration,
            e.g. ``Process(p, shuffle, sf_size=5)``.
    """

    def __init__(self, data, f, *args, **kw):
        self.data = data
        self.f = f
        self.args = args
        self.kw = kw

    def __iter__(self):
        # Wrapping in iter() lets f return any iterable (e.g. a list), not only
        # an iterator. iter() on a generator returns the generator itself, so
        # generator-based stages behave exactly as before.
        return iter(self.f(iter(self.data), *self.args, **self.kw))
# Demo input: 30 string tokens, 'a0'..'a9' followed by 'b0'..'b9' and 'c0'..'c9'.
data = [prefix + str(n) for prefix in ('a', 'b', 'c') for n in range(10)]
def travel(d):
    """Source stage: re-emit every element of ``d`` unchanged, in order."""
    yield from d
def shuffle(d, sf_size=15):
    """Locally shuffle a stream in windows of ``sf_size`` items.

    Items are accumulated. Each time the buffer reaches ``sf_size`` items it
    is shuffled in place with ``random.shuffle`` and emitted, and a new
    buffer is started. Any leftover items at the end of the stream are
    emitted in their original (unshuffled) order.

    Args:
        d: iterable of items.
        sf_size: window size that triggers a shuffle-and-flush (default 15).

    Yields:
        Every input item exactly once.
    """
    pending = []
    for element in d:
        pending.append(element)
        if len(pending) < sf_size:
            continue
        random.shuffle(pending)
        yield from pending
        pending = []
    # Trailing partial window is flushed as-is, without shuffling.
    yield from pending
def sort(d, sort_size=5, key=None):
    """Sort a stream locally, within consecutive windows of ``sort_size`` items.

    Items are buffered. Each time the buffer reaches ``sort_size`` items it
    is sorted (using ``key`` if given) and emitted. The final partial window
    is sorted and emitted as well, so no items are dropped. The original
    version buffered in windows but never sorted, despite its name.

    Args:
        d: iterable of items. They must be mutually comparable, or ``key``
            must map them to comparable values.
        sort_size: number of items per sorted window (default 5).
        key: optional sort key, as in ``list.sort``.

    Yields:
        Every input item exactly once, sorted within each window.
    """
    buf = []
    for item in d:
        buf.append(item)
        if len(buf) >= sort_size:
            buf.sort(key=key)
            yield from buf
            buf = []
    buf.sort(key=key)
    yield from buf
def batch(d, batch_size=4):
    """Re-emit items from ``d`` in groups of ``batch_size``.

    Items are buffered. Once the buffer holds ``batch_size`` items they are
    yielded one at a time, in input order, and the buffer is reset. Items are
    yielded individually, not as lists. The final partial group is also
    flushed; the original version silently dropped the last
    ``len(d) % batch_size`` items.

    Args:
        d: iterable of items.
        batch_size: number of items buffered before each flush (default 4).

    Yields:
        Every input item exactly once, in input order.
    """
    buf = []
    for item in d:
        buf.append(item)
        if len(buf) >= batch_size:
            yield from buf
            buf = []
    # Flush the trailing partial group instead of discarding it.
    yield from buf
# Chain the stages: travel -> shuffle -> batch. To enable sorting, wrap the
# shuffle stage with Process(..., sort) before the batch stage.
p = Process(Process(Process(data, travel), shuffle), batch)
for i in p:
    print(i, 'train')
标签:迭代,Process,buf,self,dataloader,yield,pytorch,data,def 来源: https://www.cnblogs.com/lhx9527/p/15757646.html