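# 读取 tfrecord，并写入 h5 文件 (Read TFRecord files and write them into HDF5 .h5 files)
# 作者：互联网 (Author: collected from the Internet)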
import tfrecord as tfr
import h5py
import os,sys
import numpy as np
import glob
import pandas as pd
from tqdm import tqdm
class TfrecordWorker():
    """Convert a list of TFRecord files into one HDF5 file, one row per file.

    The tfrecord feature description (feature name -> tfrecord type string)
    is read from a CSV with ``label`` and ``type`` columns.  The first record
    of the first TFRecord file is inspected to learn each feature's dtype and
    shape.  Every input file then becomes row ``idx`` (its position in
    ``tfr_list``) of a per-feature HDF5 dataset, and the file's basename is
    stored in an extra ``name`` dataset.

    Typical usage::

        worker = TfrecordWorker(files)
        worker.create_h5f("out.h5")
        worker.write_h5f()
        worker.close_h5f()
    """

    def __init__(self, tfr_list, description_csv="label_type.csv"):
        """Inspect the first TFRecord file to discover the feature layout.

        Args:
            tfr_list: Paths of the TFRecord files to convert; HDF5 row ``i``
                holds file ``tfr_list[i]``.
            description_csv: CSV file with ``label`` and ``type`` columns used
                as the tfrecord description.  Defaults to ``label_type.csv``
                in the current working directory (the previous hard-coded
                value).

        Raises:
            ValueError: If ``tfr_list`` is empty or the first file contains
                no records.
        """
        self.tfr_list = list(tfr_list)
        if not self.tfr_list:
            raise ValueError("tfr_list is empty: no TFRecord files to convert")
        self.info = {"label": [], "typee": [], "shape": []}
        self.data_dir = "raw_data"  # kept for backward compatibility; not used here
        self.tfr_description = self._parse_description(description_csv)

        loader = tfr.tfrecord_loader(self.tfr_list[0], None, self.tfr_description)
        # Only one record is needed to learn the layout.  Iterating over all
        # records would append every feature name again and later make
        # create_h5f fail on duplicate dataset names.
        first_record = next(iter(loader), None)
        if first_record is None:
            raise ValueError(f"{self.tfr_list[0]} contains no records")
        for key, value in first_record.items():
            arr = np.asarray(value)
            self.info["label"].append(key)
            # dtype.type is the numpy scalar type (same as type(arr[0])) but
            # also works when the feature array is empty.
            self.info["typee"].append(arr.dtype.type)
            self.info["shape"].append(arr.shape)

        self.attr_size = len(self.info["label"])
        self.data_size = len(self.tfr_list)
        print(f"总共有{self.attr_size}个属性")
        print(f"总共有{self.data_size}个tfrecord文件")

    def create_h5f(self, h5path="./data.h5"):
        """Create ``h5path`` (mode ``"w"``, truncating any existing file) and
        allocate one dataset per feature plus a string ``name`` dataset.

        Each feature dataset has shape ``(data_size, *feature_shape)``, so
        row ``idx`` receives the feature of ``tfr_list[idx]``.
        """
        self.h5f = h5py.File(h5path, "w")
        self.dset = {}
        for label, typee, shape in zip(
            self.info["label"], self.info["typee"], self.info["shape"]
        ):
            self.dset[label] = self.h5f.create_dataset(
                label,
                shape=(self.data_size, *shape),
                compression=None,
                dtype=typee,
            )
        # NOTE(review): a feature literally called "name" would collide here.
        self.dset["name"] = self.h5f.create_dataset(
            "name",
            shape=(self.data_size,),
            compression=None,
            dtype=h5py.special_dtype(vlen=str),
        )

    def write_h5f(self):
        """Write every TFRecord file into its row of the open HDF5 file."""
        # Wrapping the list itself lets tqdm know the total for the progress bar.
        for idx, tfr_path in enumerate(tqdm(self.tfr_list)):
            self._write_one_item(tfr_path, idx)

    def close_h5f(self):
        """Close the HDF5 file opened by ``create_h5f``."""
        self.h5f.close()

    def _write_one_item(self, tfr_path, idx):
        """Copy the features of ``tfr_path`` into row ``idx`` of each dataset.

        Each file is expected to hold a single record.  If it holds more,
        later records overwrite earlier ones in the same row.
        """
        loader = tfr.tfrecord_loader(tfr_path, None, self.tfr_description)
        for record in loader:
            for key, content in record.items():
                self.dset[key][idx] = content
        # os.path.basename handles the platform separator (glob yields
        # backslash paths on Windows), unlike tfr_path.split("/").
        self.dset["name"][idx] = os.path.basename(tfr_path)

    def _parse_description(self, csv_path):
        """Read ``csv_path`` and map each stripped ``label`` to its stripped ``type``."""
        label_type = pd.read_csv(csv_path, usecols=["label", "type"])
        return {
            str(label).strip(): str(typ).strip()
            for label, typ in zip(label_type["label"], label_type["type"])
        }
def start(files, savename):
    """Convert the TFRecord paths in ``files`` into one HDF5 file ``savename``.

    The HDF5 file is closed even when writing fails part-way, so the handle
    is not leaked and whatever was written gets flushed.
    """
    worker = TfrecordWorker(files)
    try:
        worker.create_h5f(savename)
        worker.write_h5f()
    finally:
        # create_h5f may fail before the file is opened; only close what exists.
        if hasattr(worker, "h5f"):
            worker.close_h5f()
# Convert each fold's TFRecord files into its own HDF5 file (fold0.h5 .. fold3.h5).
# sorted() makes the HDF5 row order deterministic; glob's order is arbitrary.
for fold in range(4):
    start(sorted(glob.glob(f"raw_data/*fold{fold}*.tfrecord")), f"fold{fold}.h5")

# Quick sanity check of the first output file; `with` ensures it is closed.
with h5py.File('fold0.h5', 'r') as f:
    print('--iterms: ', len(f.keys()), f.keys())
    name = f['name']
    print(name[:])
# 标签 (Tags)：tfrecord, h5f, 读取, self, label, h5, data, tfr
# 来源 (Source)：https://www.cnblogs.com/geoli/p/15983442.html