其他分享
首页 > 其他分享> > transformers 之Trainer对应的数据加载

transformers 之Trainer对应的数据加载

作者:互联网

基础信息说明

数据加载

Trainer的数据加载方式主要分为两种:基于torch.utils.data.Dataset的方式加载 和 基于huggingface自带的Datasets的方式(下文用huggingface / Datasets表示)加载。以下是一些需要注意的点:(1)Seq2SeqTrainer()的train_dataset和eval_dataset参数的所传实参应为字典类型;(2)该字典实参的keys应当覆盖模型运行所需要的数据参数(本文需要包括的有:'input_ids', 'attention_mask', 'labels');(3)使用huggingface / Datasets方法加载时,传给train_dataset和eval_dataset的字典实参中,多余的key(未在模型运行所需输入参数列表中)及其相关数据数,将会在训练之前被剔除。

torch.utils.data.Dataset

重载Dataset类(dataset.py)

# -*- coding: utf-8 -*-
from torch.utils.data import Dataset

class CDNDataset(Dataset):
    def __init__(self, samples):
        super(CDNDataset, self).__init__()
        self.samples = samples

    def __getitem__(self, ite):
        res = {k_: v_[ite]for k_, v_ in self.samples.items()}
        return res

    def __len__(self):
        return len(self.samples['labels'])

加载引用(main.py 后文代码同属于本文件)

from transformers import AutoTokenizer, DataCollatorForSeq2Seq, AutoModelForSeq2SeqLM, Seq2SeqTrainingArguments, Seq2SeqTrainer
from dataset import CDNDataset

读取数据

# 读取训练集
with open('raw_data/txt_en.txt', 'r', encoding='utf-8') as fr_en, open('raw_data/txt_zh.txt', 'r', encoding='utf-8') as fr_zh:
    train_data = tokenizer([str_.strip() for str_ in fr_en.readlines()], max_length=128, padding=True,truncation=True)
    # 将tokenized的中文序列对应的input_ids作为输入数据的标签
    train_data['labels'] = tokenizer([str_.strip() for str_ in fr_zh.readlines()], max_length=128,
                                     												padding=True,truncation=True)["input_ids"]
    fr_en.close()
    fr_zh.close()
train_data = CDNDataset(train_data)

# 读取验证集
with open('raw_data/test_txt_en.txt', 'r', encoding='utf-8') as fr_en, open('raw_data/test_txt_zh.txt', 'r', encoding='utf-8') as fr_zh:
    dev_data = tokenizer([str_.strip() for str_ in fr_en.readlines()], max_length=128, padding=True,truncation=True)
    dev_data['labels'] = tokenizer([str_.strip() for str_ in fr_zh.readlines()], max_length=128, 
                                   													padding=True,truncation=True)["input_ids"]
    fr_en.close()
    fr_zh.close()
 dev_data = CDNDataset(dev_data)

huggingface / Datasets

修改main.py中数据集读取部分的代码

from datasets import load_dataset
from transformers import AutoTokenizer, DataCollatorForSeq2Seq, AutoModelForSeq2SeqLM, Seq2SeqTrainingArguments, Seq2SeqTrainer
"""利用load_dataset()来读取数据:
	- 该方法支持.txt、.csv、.json等文件格式
	- 返回结果是一个字典类型
	- 读取.txt文件时,若不指定名称,这key为"text", 且会返回文本中的样本数(段落数)
	- 在读取.json文件时,若所有样本放在一个josn文件中,则返回的样本数为1(无法优雅地调用train_test_split()进行数据集分割),名称为默认名或者最层字典所       对应的keys;
	- 将每个json文件仅存放一个样本,并把这些文件放在某一目录,可使利用load_dataset()正确计算出样本数。但该目录下每个.json文件命名风格要一致(例如:txt1.json、txt2.json、、、),文件名差异较大的话,系统会只读取某一类命名格式相近的文件中的数据。
	""" 
books = load_dataset("raw_data", data_dir='test_en', name='translation')

books = books["train"].train_test_split(test_size=0.15)

source_lang = "en"
target_lang = "zh"
prefix = "translate English to Chinese: "  # 其实我也还没搞懂为啥要加这样一个前缀


def preprocess_function(examples):
    inputs = [prefix + example[source_lang] for example in examples["translation"]]
    targets = [example[target_lang] for example in examples["translation"]]
    model_inputs = tokenizer(inputs, max_length=128, truncation=True)

    with tokenizer.as_target_tokenizer():
        labels = tokenizer(targets, max_length=128, truncation=True)

    model_inputs["labels"] = labels["input_ids"]
    return model_inputs

tokenized_books = books.map(preprocess_function, batched=True)

模型及参数加载

tokenizer = AutoTokenizer.from_pretrained("opus-mt-en-zh")
model = AutoModelForSeq2SeqLM.from_pretrained("opus-mt-en-zh")
#使用huggingface/Datasets方式加载数据时,可以用DataCollator达到批处理的效果
data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, model=model) #用torch.utils.data.Dataset方式加载时,不需要

training_args = Seq2SeqTrainingArguments(
    output_dir="./results",
    evaluation_strategy="epoch",
    learning_rate=2e-5,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    weight_decay=0.01,
    save_total_limit=3,
    num_train_epochs=2,
    fp16=True,
)

模型训练

本文以Seq2SeqTrainer作为实例来进行介绍。

trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=train_data, 
    eval_dataset=dev_data,
    tokenizer=tokenizer,
    data_collator=data_collator, #用torch.utils.data.Dataset方式加载时,此参数不需要
)

补充说明

标签:Trainer,__,fr,en,zh,True,transformers,data,加载
来源: https://www.cnblogs.com/teanon/p/16583085.html