
semi-supervised


1. Generate pseudo-labels for the unlabeled data

# Imports assumed by the snippets below.
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset, ConcatDataset
from tqdm.auto import tqdm


def get_pseudo_labels(dataset, model, threshold=0.7):
    # Label the 6786 unlabeled images.
    # This function generates pseudo-labels for a dataset using the given model.
    # It returns a noLabeledDataset (defined below) containing the images whose
    # prediction confidence exceeds the given threshold.
    # You are NOT allowed to use any models trained on external data for pseudo-labeling.
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Construct a data loader (batch_size is defined elsewhere in the notebook).
    data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=False)

    # Make sure the model is in eval mode.
    model.eval()
    # Define softmax function.
    softmax = nn.Softmax(dim=-1)

    imgList = []
    labelList = []
    # Iterate over the dataset by batches.
    for batch in tqdm(data_loader):
        img, _ = batch

        # Forward the data.
        # Using torch.no_grad() accelerates the forward process.
        with torch.no_grad():
            logits = model(img.to(device))
        logits = softmax(logits)

        # Filter the data and construct a new dataset:
        # keep only predictions whose top softmax score exceeds the threshold.
        score_list, class_list = logits.max(dim=-1)
        score_list, class_list = score_list.cpu().numpy(), class_list.cpu().numpy()
        score_filter = score_list > threshold
        class_list = class_list[score_filter]

        imgList.append(img[score_filter])
        labelList.append(class_list)

    dataset = noLabeledDataset(imgList, labelList)
    del imgList
    del labelList
    del data_loader
    # Turn off eval mode.
    model.train()
    return dataset
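
To see what the confidence filter does, here is a tiny standalone illustration (not from the original post; the numbers are made up) of the softmax + threshold step used above:

import torch
import torch.nn as nn

logits = torch.tensor([[2.0, 0.1, 0.1],   # confident prediction -> kept
                       [0.5, 0.4, 0.3]])  # low confidence -> dropped
probs = nn.Softmax(dim=-1)(logits)
scores, classes = probs.max(dim=-1)
mask = scores > 0.7
print(classes[mask])  # tensor([0]) -- only the confident prediction survives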

2. Build a dataset for the newly pseudo-labeled data

class noLabeledDataset(Dataset):
    def __init__(self, imgList, labelList):
        # imgList:   list of image tensors, one tensor per filtered batch.
        # labelList: list of pseudo-label arrays, aligned with imgList.
        n = len(imgList)

        # Concatenate the filtered batches into one tensor along the batch dimension.
        x = torch.cat([imgList[i] for i in range(n)], 0)
        del imgList
        # Flatten the per-batch label arrays into a single flat list.
        y = [label for labels in labelList for label in labels]
        del labelList
        self.len = x.shape[0]
        self.x_data = x
        self.y_data = y

    def __getitem__(self, index):
        return self.x_data[index], self.y_data[index]

    def __len__(self):
        return self.len
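
A quick usage sketch for the class above, using dummy tensors (image size and label values here are assumptions for illustration only):

import torch
import numpy as np

fake_imgs = [torch.randn(4, 3, 128, 128), torch.randn(2, 3, 128, 128)]  # two filtered batches
fake_labels = [np.array([1, 0, 3, 2]), np.array([5, 4])]                # matching pseudo-labels

ds = noLabeledDataset(fake_imgs, fake_labels)
print(len(ds))           # 6
img, label = ds[0]
print(img.shape, label)  # torch.Size([3, 128, 128]) 1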

3. Merge with the labeled training set

# The number of training epochs.
n_epochs = 20

# Whether to do semi-supervised learning.
do_semi = True

for epoch in range(n_epochs):
    # ---------- TODO ----------
    # In each epoch, relabel the unlabeled dataset for semi-supervised learning.
    # Then you can combine the labeled dataset and pseudo-labeled dataset for the training.
    if do_semi:
        # Obtain pseudo-labels for unlabeled data using trained model.

        pseudo_set = get_pseudo_labels(unlabeled_set, model)

        # Construct a new dataset and a data loader for training.
        # This is used in semi-supervised learning only.
        concat_dataset = ConcatDataset([train_set, pseudo_set])
        
        train_loader = DataLoader(concat_dataset, batch_size=batch_size, shuffle=True, num_workers=8, pin_memory=True)

        del concat_dataset
        del pseudo_set
    print(len(train_loader))
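
The loop above stops right after building train_loader; the usual per-batch training step would follow inside the epoch loop. A minimal sketch, assuming a criterion (e.g. nn.CrossEntropyLoss), an optimizer, and device are already defined elsewhere in the notebook:

    model.train()
    for batch in tqdm(train_loader):
        imgs, labels = batch
        logits = model(imgs.to(device))
        loss = criterion(logits, labels.to(device))

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()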

Source: https://blog.csdn.net/LeCarry/article/details/121136258