
Few-Shot Data Generation


Preface

In many scenarios we only have a handful of samples, while training a network model takes large amounts of data. One remedy is transfer learning, e.g., using pretrained models. Here, however, we look at the problem from another angle: data augmentation. Many augmentation methods already exist; this post covers some common ones and, in particular, some of the most recent (as of 2021-01-28).

Common traditional methods

For text data, the easiest options are shuffle, drop, synonym replacement, back-translation, random insertion, and so on. These basic methods augment data by perturbing the tokens themselves. For more, see the nlpcda Python package:

https://github.com/425776024/nlpcda
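As a minimal plain-Python sketch of two of these token-level perturbations (random drop and random swap), purely for illustration; the nlpcda package provides far more complete implementations:

import random

def augment_tokens(tokens, p_drop=0.1, n_swaps=1):
    # randomly drop tokens (but keep at least one)
    out = [t for t in tokens if random.random() > p_drop] or tokens[:1]
    # randomly swap token pairs (a light form of shuffle)
    for _ in range(n_swaps):
        if len(out) > 1:
            i, j = random.sample(range(len(out)), 2)
            out[i], out[j] = out[j], out[i]
    return out

print(augment_tokens("few shot data generation is hard".split()))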

For images, the counterparts are rotation, translation, cropping, and so on.
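These image transforms can be sketched with, e.g., torchvision (an assumption here; any augmentation library works, and img would be a PIL image):

from torchvision import transforms

augment = transforms.Compose([
    transforms.RandomRotation(degrees=15),                     # rotation
    transforms.RandomAffine(degrees=0, translate=(0.1, 0.1)),  # translation
    transforms.RandomResizedCrop(size=224),                    # crop and rescale
])
# augmented = augment(img)  # img: a PIL.Image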

Free Lunch for Few-shot Learning: Distribution Calibration

A paper published at ICLR 2021, with code available. Its main theoretical ingredient is two statistics of data-rich base classes, the mean and covariance; if no large base classes are available, or in a cross-domain setting, the approach likely breaks down.

Paper: https://arxiv.org/pdf/2101.06395.pdf

The author's own write-up: https://zhuanlan.zhihu.com/p/344531704

A code walkthrough: https://zhuanlan.zhihu.com/p/346956374

For background on the N-way K-shot task setup, see

https://zhuanlan.zhihu.com/p/151311431
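As a quick illustration of that setup, here is a hypothetical sketch of sampling one N-way K-shot episode from a labeled pool (the function name and structure are mine, not from the paper):

import random
from collections import defaultdict

def sample_episode(samples, labels, n_way=3, k_shot=5):
    # group the pool by class
    by_class = defaultdict(list)
    for x, y in zip(samples, labels):
        by_class[y].append(x)
    # choose N classes, then K support examples per class
    classes = random.sample(list(by_class), n_way)
    return {c: random.sample(by_class[c], k_shot) for c in classes}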

The main idea is as follows:

1. On the large base dataset, compute each class's mean base_mean and covariance base_cov (with 10 base classes, that gives 10 mean/covariance pairs).
2. For each sample of a few-shot class, compute its distance to each of the 10 base-class means and select the k nearest (k is a hyperparameter you set, e.g., 2), which picks out k means and k covariances.
3. Calibrate that sample's statistics: the new mean is the average of the sample itself and the k selected base means (k+1 vectors in total), and the new covariance is the average of the k selected covariances.
4. Build a Gaussian distribution from the calibrated mean and covariance and draw m samples from it.

If the support set has 20 samples, after augmentation there are 20*m samples, and the new samples follow the calibrated data distribution rather than being mere perturbations of the originals.
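In symbols: for a support sample $\tilde{x}$ whose $k$ nearest base classes form the set $\mathcal{S}$, the calibrated statistics are

$$\mu' = \frac{\tilde{x} + \sum_{i \in \mathcal{S}} \mu_i}{k+1}, \qquad \Sigma' = \frac{1}{k} \sum_{i \in \mathcal{S}} \Sigma_i + \alpha,$$

where $\alpha$ is a hyperparameter that spreads out the calibrated covariance (0.5 by default in the code below). This is exactly what the distribution_calibration function in the full code computes.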

Note the key premise here: base_mean and base_cov are prior knowledge taken from the large dataset. Without some base data to compute these statistics from, a small sample alone is not enough to generate a dataset; the prior knowledge we rely on is precisely these statistics (means and covariances) from the large base dataset.

Here I ran a simple experiment on the iris dataset:

Each colour represents one class (3 classes in total). Stars are the few-shot samples, 5 per class, which are used to generate the data; circles are the real data, 50 real samples per class; triangles are generated samples, 10 per few-shot sample (star), for 3*5*10 = 150 generated samples in total.

Full code:

import pandas as pd
import numpy as np
from sklearn import datasets
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE


def load_iris_dataset():
    # load iris as a DataFrame: 4 feature columns plus a "label" column
    iris = datasets.load_iris()
    dataset = pd.DataFrame(iris.data, columns=iris.feature_names)
    dataset["label"] = iris.target
    return dataset, iris.feature_names


def distribution_calibration(query, base_means, base_cov, k, alpha=0.5):
    # distance from the query sample to every base-class mean
    dist = [np.linalg.norm(query - base_means[i]) for i in range(len(base_means))]
    # indices of the k nearest base classes
    index = np.argpartition(dist, k)[:k]
    # calibrated mean: average of the query and the k selected base means
    mean = np.concatenate([np.array(base_means)[index], query[np.newaxis, :]])
    calibrated_mean = np.mean(mean, axis=0)
    # calibrated covariance: average of the k selected base covariances,
    # plus alpha to increase the spread of the calibrated distribution
    calibrated_cov = np.mean(np.array(base_cov)[index], axis=0) + alpha

    return calibrated_mean, calibrated_cov


def get_min_sample(dataset, feature_name):
    base_means = []
    base_cov = []
    support_parts = []

    df_group = dataset.groupby("label")
    print(" label num :" + str(len(df_group)))
    for label, df in df_group:
        # draw a small support set: 10% of each class
        support_parts.append(df.sample(frac=0.1, axis=0))

        # per-class base statistics, computed on the full class data
        feature = df[feature_name].values
        base_means.append(np.mean(feature, axis=0))
        base_cov.append(np.cov(feature.T))
    result = pd.concat(support_parts, ignore_index=True)
    return result, base_means, base_cov


def generate_dataset(result, base_means, base_cov, feature_name, num_sampled=10, return_origion=False):
    sampled_data = []
    sampled_label = []

    sample_num = result.shape[0]
    feature = result[feature_name].values
    label = result["label"].values
    for i in range(sample_num):
        # calibrate this support sample's statistics against the base classes
        mean, cov = distribution_calibration(feature[i], base_means, base_cov, k=1)
        # draw num_sampled new points from the calibrated Gaussian
        sampled_data.append(np.random.multivariate_normal(mean=mean, cov=cov, size=num_sampled))
        sampled_label.extend([label[i]] * num_sampled)

    sampled_data = np.concatenate(sampled_data, axis=0)
    result_aug = pd.DataFrame(sampled_data, columns=feature_name)
    result_aug["label"] = sampled_label

    # optionally return the original support set together with the generated samples
    if return_origion:
        result_aug = pd.concat([result_aug, result], ignore_index=True)
    
    return result_aug
    
    
def Visualise(base_class_dataset, result, result_aug, feature_name):
    # stack real data, support set, and generated data so they share one embedding
    dataset = pd.concat([base_class_dataset, result, result_aug])
    feature = dataset[feature_name].values
    label = dataset.label.values

    label_map = {"0": 'red', "1": 'black', "2": 'peru'}

    # project everything to 2D with t-SNE (PCA also works)
    transform = TSNE  # PCA
    trans = transform(n_components=2)
    feature = trans.fit_transform(feature)

    # split the embedding back into the three groups
    base_class_dataset_feature = feature[:base_class_dataset.shape[0]]
    result_feature = feature[base_class_dataset.shape[0]:base_class_dataset.shape[0] + result.shape[0]]
    result_aug_feature = feature[base_class_dataset.shape[0] + result.shape[0]:]

    base_class_dataset_label = label[:base_class_dataset.shape[0]]
    base_class_dataset_label_colours = [label_map[str(target)] for target in base_class_dataset_label]

    result_label = label[base_class_dataset.shape[0]:base_class_dataset.shape[0] + result.shape[0]]
    result_label_colours = [label_map[str(target)] for target in result_label]

    result_aug_label = label[base_class_dataset.shape[0] + result.shape[0]:]
    result_aug_label_colours = [label_map[str(target)] for target in result_aug_label]

    plt.figure(figsize=(20, 15))
    plt.gca().set_aspect("equal")
    # circles: real data; triangles: generated samples; stars: support samples
    plt.scatter(base_class_dataset_feature[:, 0], base_class_dataset_feature[:, 1], c=base_class_dataset_label_colours, marker='o', s=100)
    plt.scatter(result_aug_feature[:, 0], result_aug_feature[:, 1], c=result_aug_label_colours, marker='^', s=100)
    plt.scatter(result_feature[:, 0], result_feature[:, 1], c=result_label_colours, marker='*', s=800)

    plt.title("t-SNE visualization of real, support, and generated samples")
    # plt.savefig("AllNode.png")
    plt.show()
    # print the colour legend
    for i in label_map:
        print("class %s -> %s" % (i, label_map[i]))
        

def main():
    # load the original iris dataset as the base data
    base_class_dataset, feature_name = load_iris_dataset()
    print(base_class_dataset.shape)
    print("******" * 5)
    print(base_class_dataset.head(3))

    # draw the support set and compute per-class distribution statistics;
    # k and alpha (used in distribution_calibration) are hyperparameters:
    # k can grow with the total number of classes, but always k < num_label
    result, base_means, base_cov = get_min_sample(base_class_dataset, feature_name)
    # generate data from the calibrated distributions
    result_aug = generate_dataset(result, base_means, base_cov, feature_name, num_sampled=10)
    print(result.shape)
    print(result_aug.shape)
    # visualise real, support, and generated samples
    Visualise(base_class_dataset, result, result_aug, feature_name)


if __name__=='__main__':
    main()
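For reference, on iris (150 samples, 4 features) the 10% support set has 15 rows, so with num_sampled = 10 the augmented set result_aug should have 15 * 10 = 150 rows, matching the 3*5*10 count described above.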

Source: https://blog.csdn.net/weixin_42001089/article/details/113307918