其他分享
首页 > 其他分享 > 读取tfrecord,并写入h5文件

读取tfrecord,并写入h5文件

作者:互联网

import tfrecord as tfr
import h5py
import os,sys
import numpy as np
import glob
import pandas as pd
from tqdm import tqdm
class TfrecordWorker():
    """Copy the features of a list of TFRecord files into one HDF5 file.

    The schema (feature names, element types, shapes) is probed from the
    first record of the first file. Each file then fills one row, at the
    file's position in ``tfr_list``, of every per-feature dataset. An extra
    ``name`` dataset stores the base name of each file.

    Typical use: ``create_h5f()``, then ``write_h5f()``, then ``close_h5f()``.
    """

    def __init__(self, tfr_list):
        """Load the feature description and probe the schema of the first file.

        Args:
            tfr_list: Paths of the TFRecord files to convert. The first entry
                is opened immediately to discover the schema.

        Raises:
            IndexError: If ``tfr_list`` is an empty sequence.
        """
        self.info = {"label":[],"typee":[],"shape":[]}
        self.data_dir = "raw_data"
        self.tfr_list = tfr_list
        # The description CSV is looked up relative to the working directory.
        self.tfr_description = self._parse_description("label_type.csv")
        loader = tfr.tfrecord_loader(self.tfr_list[0], None, self.tfr_description)
        self._collect_info(loader)
        self.attr_size = len(self.info['label'])
        self.data_size = len(self.tfr_list)
        print(f"总共有{self.attr_size}个属性")
        print(f"总共有{self.data_size}个tfrecord文件")

    def _collect_info(self, records):
        """Record label, element type and shape of each feature of the first record.

        Only the first record is inspected: walking every record would append
        the same labels again and later make ``create_dataset`` fail on the
        duplicated names. If ``records`` yields nothing, ``self.info`` is
        left untouched.

        Args:
            records: Iterable of ``{feature name: array}`` mappings.
        """
        first = next(iter(records), None)
        if first is None:
            return
        for key, value in first.items():
            self.info['label'].append(key)
            # dtype.type is the element class without indexing the array,
            # so empty features do not raise.
            self.info['typee'].append(value.dtype.type)
            self.info['shape'].append(value.shape)

    def create_h5f(self, h5path="./data.h5"):
        """Create (truncating) the HDF5 file and one dataset per feature.

        Every feature dataset has ``data_size`` rows whose trailing dimensions
        are the feature's probed shape. A variable-length string dataset
        called ``name`` is added for the source file names.

        Args:
            h5path: Destination path of the HDF5 file.
        """
        self.h5f = h5py.File(h5path, 'w')
        self.dset = {}
        schema = zip(self.info["label"], self.info["typee"], self.info["shape"])
        for label, typee, shape in schema:
            self.dset[label] = self.h5f.create_dataset(
                label,
                # Keep all feature dimensions, not just the first one.
                shape=(self.data_size, *shape),
                compression=None,
                dtype=typee)

        self.dset["name"] = self.h5f.create_dataset(
            "name",
            shape=[self.data_size],
            compression=None,
            dtype=h5py.special_dtype(vlen=str))

    def write_h5f(self):
        """Write every file of ``tfr_list`` into its row of the datasets."""
        # enumerate() has no length, so tell tqdm the total explicitly.
        progress = tqdm(enumerate(self.tfr_list), total=len(self.tfr_list))
        for idx, tfr_path in progress:
            self._write_one_item(tfr_path, idx)

    def close_h5f(self):
        """Close the HDF5 file opened by ``create_h5f``."""
        self.h5f.close()

    @staticmethod
    def _record_name(tfr_path):
        """Return the file name of ``tfr_path`` without its directory part."""
        return os.path.basename(tfr_path)

    def _write_one_item(self, tfr_path, idx):
        """Store the features of one TFRecord file in row ``idx``.

        If the file holds several records, each one overwrites the same row,
        so the last record is the one that remains.

        Args:
            tfr_path: Path of the TFRecord file to read.
            idx: Row index to write in every dataset.
        """
        loader = tfr.tfrecord_loader(tfr_path, None, self.tfr_description)
        for record in loader:
            for key, content in record.items():
                self.dset[key][idx] = content
        self.dset["name"][idx] = self._record_name(tfr_path)

    def _parse_description(self, csv_path):
        """Build the ``{label: type}`` feature description from a CSV file.

        Args:
            csv_path: CSV file with at least the columns ``label`` and
                ``type``; other columns are ignored.

        Returns:
            Dict mapping each whitespace-stripped label to its
            whitespace-stripped type string.
        """
        label_type = pd.read_csv(csv_path, usecols=["label","type"])
        return {
            str(row['label']).strip(): str(row['type']).strip()
            for _, row in label_type.iterrows()
        }



def start(files, savename):
    """Convert the TFRecord files in ``files`` into the HDF5 file ``savename``.

    Args:
        files: Paths of the TFRecord files to convert.
        savename: Path of the HDF5 file to create.
    """
    worker = TfrecordWorker(files)
    worker.create_h5f(savename)
    try:
        worker.write_h5f()
    finally:
        # Close the file even when a write fails so the handle is not leaked.
        worker.close_h5f()

# Convert the files of fold0..fold3 into one HDF5 file per fold. The glob
# result is sorted so the row order of each output file is reproducible.
for fold in range(4):
    start(sorted(glob.glob(f"raw_data/*fold{fold}*.tfrecord")), f"fold{fold}.h5")

# Sanity check: list the datasets of the first output file and its stored
# file names. The context manager closes the file once the check is done.
with h5py.File('fold0.h5', 'r') as f:
    print('--iterms: ', len(f.keys()), f.keys())
    name = f['name']
    print(name[:])

标签:tfrecord,h5f,读取,self,label,h5,data,tfr
来源: https://www.cnblogs.com/geoli/p/15983442.html