其他分享
首页 > 其他分享> > 读代码:Presence

读代码:Presence

作者:互联网

ReadMe

demo.py 是一个简单的demo,1)将位置作为输入,返回一个对所有分类是否会在该位置存在的预测,或2)为一个感兴趣的类别生成一个密集的预测。

geo_prior/ 地理先验,包括主要的训练和评估模型的代码
gen_figs/ 生成图像,包括重新创建paper中的图像
pre_process/ 预处理,包括训练图像分类器和保存特征/预测
web_app/ 包括基于web的模型预测可视化的代码

demo.py 

"""
Demo that either 1) takes a location as input and returns a prediction indicating
the likelihood that each category is present there, or 2) takes a category ID as
input and generates a prediction for each location on the globe.
"""

1)以一个位置作为输入,返回对每个类在这里存在的可能性的预测 

或者2)以一个类别ID作为输入,为地球的每个位置生成一个预测

就是有两种方式,一种输入位置经纬度坐标,返回的是在这个位置,每个物种存在的概率,或者输入一个物种类别的ID,为地球上的每个经纬坐标生成该物种存在的预测。

import argparse
import numpy as np
import json
import matplotlib.pyplot as plt
import torch
import os
from six.moves import urllib

from geo_prior import models
from geo_prior import utils
from geo_prior import grid_predictor as grid

下载模型 download_model

def download_model(model_url, model_path):

    # Download pre-trained model if it is not currently available.
    if not os.path.isfile(model_path):
        try:
            print('Downloading model from: ' + model_url)
            urllib.request.urlretrieve(model_url, model_path)
        except:
            print('Failed to download model from: ' + model_url)

输入参数是model的url链接,和model的路径。这个函数用于模型目前不可用时,下载预训练的模型,如果该路径存在,则使用urllib.request.urlretrieve从传参的url链接下载模型。

主函数 main

def main(args):

    download_model(args.model_url, args.model_path)
    print('Loading model: ' + args.model_path)
    net_params = torch.load(args.model_path, map_location='cpu')
    params = net_params['params']
    model = models.FCNet(num_inputs=params['num_feats'], num_classes=params['num_classes'],
                         num_filts=params['num_filts'], num_users=params['num_users']).to(params['device'])
    model.load_state_dict(net_params['state_dict'])
    model.eval()

 首先调用上面的download_model函数,接着从路径中加载下载的模型,调用models.FCNet。

# load class names
    with open(args.class_names_path) as da:
        class_data = json.load(da)

    if args.demo_type == 'location':
        # convert coords to torch
        coords = np.array([args.longitude, args.latitude])[np.newaxis, ...]
        obs_coords = utils.convert_loc_to_tensor(coords, params['device'])
        obs_time = torch.ones(coords.shape[0], device=params['device'])*args.time_of_year*2 - 1.0
        loc_time_feats = utils.encode_loc_time(obs_coords, obs_time, concat_dim=1, params=params)

        print('Making prediction ...')
        with torch.no_grad():
            pred = model(loc_time_feats)[0, :]
        pred = pred.cpu().numpy()

        num_categories = 25
        print('\nTop {} likely categories for location {:.4f}, {:.4f}:'.format(num_categories, coords[0,0], coords[0,1]))
        most_likely = np.argsort(pred)[::-1]
        for ii, cls_id in enumerate(most_likely[:num_categories]):
            print('{}\t{}\t{:.3f}'.format(ii, cls_id, np.round(pred[cls_id], 3)) + \
                '\t' + class_data[cls_id]['our_name'] + ' - ' + class_data[cls_id]['preferred_common_name'])

加载类名,这里有一个if-else分支,根据demo_type的不同,分为“location”和“map”。 

如果是location:

with语句的原理

    elif args.demo_type == 'map':
        # grid predictor - for making dense predictions for each lon/lat location
        gp = grid.GridPredictor(np.load('data/ocean_mask.npy'), params, mask_only_pred=True)

        if args.class_of_interest == -1:
            args.class_of_interest = np.random.randint(len(class_data))
        print('Selected category: ' + class_data[args.class_of_interest]['our_name'] +\
            ' - ' + class_data[args.class_of_interest]['preferred_common_name'])

        print('Making prediction ...')
        grid_pred = gp.dense_prediction(model, args.class_of_interest, time_step=args.time_of_year)

        op_file_name = class_data[args.class_of_interest]['our_name'].lower().replace(' ', '_') + '.png'
        print('Saving prediction to: ' + op_file_name)
        plt.imsave(op_file_name, 1.0-grid_pred, cmap='afmhot', vmin=0, vmax=1

如果demo_type是"map":为每个经纬度位置生成密集预测

if __name__ == "__main__":

    info_str = '\nPresence-Only Geographical Priors for Fine-Grained Image Classification.\n\n' + \
               'This demo can be run in one of two ways:\n' + \
               '1) Give a location and get a list of most likely classes there e.g\n' + \
               '   python demo.py location --longitude -118.1445155 --latitude 34.1477849 --time_of_year 0.5\n' + \
               'Input coordinates should be in decimal degrees i.e. ' + \
               'Longitude: [-180, 180], Latitude: [-90, 90], and Time of year [0, 1].\n\n' + \
               '2) Give a category ID as input and get a prediction for each location on the globe for that category e.g.\n' + \
               '   python demo.py map --class_of_interest 3731\n' + \
               'If class_of_interest is not specified a random one will be selected.\n\n'

    model_path = 'models/model_inat_2018_full_final.pth.tar'
    model_url  = 'http://www.vision.caltech.edu/~macaodha/projects/geopriors/model_inat_2018_full_final.pth.tar'
    class_names_path = 'web_app/data/categories2018_detailed.json'

info_str:细粒度图像分类的仅存在地理先验。这个demo可以用以下两种方法的一种运行:

  1. 给定一个位置,得到一个最可能在这里存在的物种list。例如:
    python demo.py location --longitude -118.1445155 --latitude 34.1477849 --time_of_year 0.5

    输入的坐标需要用数字表示经纬度,即Longitude:[-180,180],Ltitude:[-90,90],Time of year[0,1]

  2. 给定一个类别ID作为输入,得到全球每个位置上这个物种存在的预测。例如:
    python demo.py map --class_of_interest 3731

    如果没有指定类别,就随机生成一个类ID。

给出model_path,model_url和class_names_path 

标签:Presence,demo,代码,args,params,time,model,class
来源: https://blog.csdn.net/weixin_39627422/article/details/120456504