如何生成锚框并在图片上可视化
作者:互联网
import torch
# Public alias of torch._C.Size; unused here, kept so the name stays in scope.
from torch import Size  # noqa: F401

torch.set_printoptions(2)  # compact tensor printing (two decimal places)


def multibox_prior(data, sizes, ratios):
    """Generate anchor boxes of different shapes centred on every pixel.

    For each pixel of the last two dimensions of ``data`` this creates
    ``len(sizes) + len(ratios) - 1`` boxes: every size combined with the
    first ratio, plus the first size combined with each remaining ratio.

    Args:
        data: Tensor whose last two dimensions are (height, width). Only its
            shape and device are read.
        sizes: Non-empty sequence of box scales, relative to the image height.
        ratios: Non-empty sequence of width/height aspect ratios.

    Returns:
        Tensor of shape ``(1, height * width * boxes_per_pixel, 4)`` holding
        ``(xmin, ymin, xmax, ymax)`` in coordinates normalised by the image
        width (x) and height (y). Boxes are ordered row by row over the
        pixels, with the boxes of one pixel stored contiguously. Boxes are
        not clipped, so coordinates may fall outside ``[0, 1]``.
    """
    in_height, in_width = data.shape[-2:]
    device, num_sizes, num_ratios = data.device, len(sizes), len(ratios)
    boxes_per_pixel = num_sizes + num_ratios - 1
    size_tensor = torch.tensor(sizes, device=device)
    ratio_tensor = torch.tensor(ratios, device=device)

    # A pixel has height 1 and width 1, so an offset of 0.5 moves the anchor
    # point to the pixel centre.
    offset_h, offset_w = 0.5, 0.5
    steps_h = 1.0 / in_height  # normalised step along y
    steps_w = 1.0 / in_width   # normalised step along x

    # Centre points of all anchor boxes, in normalised coordinates.
    center_h = (torch.arange(in_height, device=device) + offset_h) * steps_h
    center_w = (torch.arange(in_width, device=device) + offset_w) * steps_w
    # Explicit "ij" indexing keeps the (row, column) layout and avoids the
    # deprecation warning raised when the argument is omitted.
    shift_y, shift_x = torch.meshgrid(center_h, center_w, indexing="ij")
    shift_y, shift_x = shift_y.reshape(-1), shift_x.reshape(-1)

    # Widths and heights of the `boxes_per_pixel` boxes. The factor
    # in_height / in_width rescales the width so that the requested aspect
    # ratio holds in pixel space even for non-square inputs.
    w = torch.cat((size_tensor * torch.sqrt(ratio_tensor[0]),
                   sizes[0] * torch.sqrt(ratio_tensor[1:]))) \
        * in_height / in_width
    h = torch.cat((size_tensor / torch.sqrt(ratio_tensor[0]),
                   sizes[0] / torch.sqrt(ratio_tensor[1:])))

    # Divide by 2 to get half-widths / half-heights: the offsets from a
    # centre to the (xmin, ymin, xmax, ymax) corners.
    anchor_manipulations = torch.stack((-w, -h, w, h)).T.repeat(
        in_height * in_width, 1) / 2

    # Every centre owns `boxes_per_pixel` boxes, so repeat each centre that
    # many times to line up with the tiled corner offsets.
    out_grid = torch.stack(
        [shift_x, shift_y, shift_x, shift_y], dim=1
    ).repeat_interleave(boxes_per_pixel, dim=0)
    output = out_grid + anchor_manipulations
    return output.unsqueeze(0)


def show_bboxes(axes, bboxes, labels=None, colors=None):
    """Draw all bounding boxes on ``axes``.

    Args:
        axes: Drawing target providing ``add_patch`` and ``text``
            (presumably a matplotlib ``Axes``).
        bboxes: Iterable of box tensors in the format accepted by
            ``d2l.bbox_to_rect``.
        labels: Optional label or list/tuple of labels; box ``i`` is
            annotated with ``labels[i]`` when that entry exists.
        colors: Optional colour or list/tuple of colours, cycled over the
            boxes. Defaults to ``['b', 'g', 'r', 'm', 'c']``.
    """
    # Imported lazily so the anchor-generation code does not need d2l.
    from d2l import torch as d2l

    def _make_list(obj, default_values=None):
        # Normalise None / scalar / sequence input to a list-like value.
        if obj is None:
            obj = default_values
        elif not isinstance(obj, (list, tuple)):
            obj = [obj]
        return obj

    labels = _make_list(labels)
    colors = _make_list(colors, ['b', 'g', 'r', 'm', 'c'])
    for i, bbox in enumerate(bboxes):
        color = colors[i % len(colors)]
        # .cpu() makes the NumPy conversion work for accelerator tensors too.
        rect = d2l.bbox_to_rect(bbox.detach().cpu().numpy(), color)
        axes.add_patch(rect)
        if labels and len(labels) > i:
            # Keep the label readable on a white box background.
            text_color = 'k' if color == 'w' else 'w'
            axes.text(rect.xy[0], rect.xy[1], labels[i],
                      va='center', ha='center', fontsize=9, color=text_color,
                      bbox=dict(facecolor=color, lw=0))


if __name__ == "__main__":
    from d2l import torch as d2l

    img = d2l.plt.imread('../img/catdog.jpg')
    h, w = img.shape[:2]
    print(h, w)
    X = torch.rand(size=(1, 3, h, w))
    sizes, ratios = [0.75, 0.5, 0.25], [1, 2, 0.5]
    Y = multibox_prior(X, sizes=sizes, ratios=ratios)
    # `shape` is an attribute, not a method; calling it raises TypeError.
    print(Y.shape)
    # View as (row, column, box index, coordinates) to look boxes up by pixel.
    boxes = Y.reshape(h, w, len(sizes) + len(ratios) - 1, 4)
    print(boxes[250, 250, 0, :])
#原链接:https://zh-v2.d2l.ai/chapter_preface/index.html
标签:tensor,sizes,shift,锚框,torch,生成,boxes,可视化,device 来源: https://www.cnblogs.com/chuxinbubian/p/15515455.html