Semi-supervised learning
作者:互联网
1. 为无标签数据生成伪标签
def get_pseudo_labels(dataset, model, threshold=0.7):
    """Pseudo-label ``dataset`` with ``model`` and keep only confident samples.

    Every sample is passed through ``model``. A softmax over the last
    dimension of the output gives the confidence of the predicted class.
    Only samples whose confidence is strictly greater than ``threshold`` are
    kept, paired with the predicted class index as their label.

    You are NOT allowed to use any models trained on external data for
    pseudo-labeling.

    Args:
        dataset: unlabeled dataset yielding ``(img, _)`` pairs; the second
            element is ignored.
        model: classifier whose output holds per-class scores along the last
            dimension.
        threshold: exclusive lower bound on the softmax confidence for a
            sample to be kept.

    Returns:
        A ``noLabeledDataset`` with the kept images (as yielded by the loader,
        not moved to the model's device) and their predicted class indices.

    Note:
        Uses the module-level ``batch_size`` for the loader. The model is
        always left in train mode on return, even if prediction raised.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    # shuffle=False: order is irrelevant for labeling and keeps runs deterministic.
    data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
    softmax = nn.Softmax(dim=-1)

    kept_imgs = []
    kept_labels = []

    model.eval()  # inference behaviour for dropout / batch-norm layers
    try:
        # One no_grad context for the whole pass: no autograd graph is needed.
        with torch.no_grad():
            for img, _ in tqdm(data_loader):
                probs = softmax(model(img.to(device)))
                scores, classes = probs.max(dim=-1)
                # Build the mask on the CPU so it can index the loader's batch directly.
                keep = (scores > threshold).cpu()
                kept_imgs.append(img[keep])
                # Store plain Python ints rather than numpy scalars.
                kept_labels.append(classes.cpu()[keep].tolist())
    finally:
        # Turn off eval mode, even if prediction failed part-way.
        model.train()

    return noLabeledDataset(kept_imgs, kept_labels)
2. 用新标记的数据构建数据集
class noLabeledDataset(Dataset):
    """In-memory dataset of pseudo-labeled images.

    Args:
        imgList: sequence of image batches (tensors). They are concatenated
            along dim 0, so all batches must share the same trailing shape.
            It may be empty, which yields an empty dataset.
        labelList: sequence of per-batch label sequences aligned with
            ``imgList``. They are flattened into a single list of Python ints.

    Raises:
        ValueError: if the total number of images and labels differ.
    """

    def __init__(self, imgList, labelList):
        labels = [int(label) for labels in labelList for label in labels]
        if not imgList:
            # torch.cat([]) raises; with nothing selected, produce an empty dataset.
            images = torch.empty(0)
        else:
            images = torch.cat(list(imgList), dim=0)
        if images.shape[0] != len(labels):
            raise ValueError(
                f"got {images.shape[0]} images but {len(labels)} labels"
            )
        self.len = images.shape[0]
        self.x_data = images
        self.y_data = labels

    def __getitem__(self, index):
        """Return the ``(image, label)`` pair at ``index``."""
        return self.x_data[index], self.y_data[index]

    def __len__(self):
        """Return the number of samples."""
        return self.len
3. 与训练集合并
# The number of training epochs.
n_epochs = 20

# Whether to do semi-supervised learning.
do_semi = True

for epoch in range(n_epochs):
    # In each epoch, relabel the unlabeled dataset with the current model, then
    # train on the combination of labeled and pseudo-labeled data.
    if do_semi:
        # Obtain pseudo-labels for unlabeled data using the model trained so far.
        # NOTE(review): at epoch 0 this uses an untrained model unless it was
        # trained before this loop; confirm that is intended.
        pseudo_set = get_pseudo_labels(unlabeled_set, model)

        # Construct a new dataset and data loader for this epoch's training.
        # This rebinds train_loader; when do_semi is False, train_loader is
        # presumably the one built earlier from train_set alone.
        concat_dataset = ConcatDataset([train_set, pseudo_set])
        train_loader = DataLoader(
            concat_dataset,
            batch_size=batch_size,
            shuffle=True,
            num_workers=8,
            pin_memory=True,
        )
        # The loader holds its own reference to the dataset; drop the local names.
        del concat_dataset
        del pseudo_set
        print(len(train_loader))
    # NOTE(review): the per-epoch training/validation steps are not part of
    # this excerpt; presumably they follow here and consume train_loader.
标签:semi,imgList,list,dataset,score,supervised,labelList,data 来源: https://blog.csdn.net/LeCarry/article/details/121136258