【图像检索】resnet50由torch转onnx转openvino
作者:互联网
0.环境
# yolov5的环境 + onnx
onnx==1.9.0
# openvino的环境
openvino_2021.3.394
1.转onnx
torch.onnx.export
2.转openvino
python3 /opt/intel/openvino/deployment_tools/model_optimizer/mo.py --input_model models/resnet50.onnx --output_dir models/resnet50_retrival --reverse_input_channels --input_shape [1,3,224,224] --mean_values [0.485,0.456,0.406] --scale_values [0.229,0.224,0.225] --output Flatten_n
其中,
--mean_values、--scale_values:修改为自己模型训练时使用的均值和标准差(注意要与推理时输入像素的取值范围保持一致,例如输入为0~255时需要相应乘以255);
--output Flatten_n:指定输出节点的名称,即你需要提取特征的那个节点;节点名可以通过netron.app查看。
3.demo
来自:/opt/intel/openvino_2021.3.394/deployment_tools/open_model_zoo/demos/image_retrieval_demo/python/image_retrieval_demo.py
#!/usr/bin/env python3
"""
Copyright (c) 2019 Intel Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import logging as log
from pathlib import Path
import sys
import time
from argparse import ArgumentParser, SUPPRESS
import cv2
import numpy as np
from image_retrieval_demo.image_retrieval import ImageRetrieval
from image_retrieval_demo.common import central_crop
from image_retrieval_demo.visualizer import visualize
from image_retrieval_demo.roi_detector_on_video import RoiDetectorOnVideo
sys.path.append(str(Path(__file__).resolve().parents[2] / 'common/python'))
import monitors
from images_capture import open_images_capture
# Input size handed to ImageRetrieval in main(); presumably the side length in
# pixels of the square network input — TODO confirm against ImageRetrieval.
INPUT_SIZE = 224
def build_argparser():
    """Create the command-line parser for the demo.

    All options are registered, in declaration order, inside a single
    'Options' argument group. The built-in help is disabled on the parser
    and re-added manually so that it is listed in the same group.

    Returns:
        The configured ArgumentParser.
    """
    parser = ArgumentParser(add_help=False)
    options = parser.add_argument_group('Options')

    # (option strings, keyword arguments) pairs; the order defines the help layout.
    option_specs = (
        (('-h', '--help'),
         dict(action='help', default=SUPPRESS,
              help='Show this help message and exit.')),
        (('-m', '--model'),
         dict(help='Required. Path to an .xml file with a trained model.',
              required=True, type=str)),
        (('-i', '--input'),
         dict(required=True,
              help='Required. Path to a video file or a device node of a web-camera.')),
        (('--loop',),
         dict(default=False, action='store_true',
              help='Optional. Enable reading the input in a loop.')),
        (('-o', '--output'),
         dict(required=False,
              help='Optional. Name of output to save.')),
        (('-limit', '--output_limit'),
         dict(required=False, default=1000, type=int,
              help='Optional. Number of frames to store in output. '
                   'If 0 is set, all frames are stored.')),
        (('-g', '--gallery'),
         dict(help='Required. Path to a file listing gallery images.',
              required=True, type=str)),
        (('-gt', '--ground_truth'),
         dict(help='Optional. Ground truth class.',
              type=str)),
        (('-d', '--device'),
         dict(help='Optional. Specify the target device to infer on: CPU, GPU, FPGA, HDDL '
                   'or MYRIAD. The demo will look for a suitable plugin for device '
                   'specified (by default, it is CPU).',
              default='CPU', type=str)),
        (("-l", "--cpu_extension"),
         dict(help="Optional. Required for CPU custom layers. Absolute path to "
                   "a shared library with the kernels implementations.", type=str,
              default=None)),
        (('--no_show',),
         dict(action='store_true',
              help='Optional. Do not visualize inference results.')),
        (('-u', '--utilization_monitors'),
         dict(default='', type=str,
              help='Optional. List of monitors to show initially.')),
    )

    for flags, settings in option_specs:
        options.add_argument(*flags, **settings)

    return parser
def compute_metrics(positions):
    """Compute top-N hit counts and the mean position of the target.

    Args:
        positions: list of numeric positions; each one is compared against
            the thresholds 1, 5 and 10.

    Returns:
        A tuple ``(top_1, top_5, top_10, mean_pos)``. The first three values
        are *counts* (not ratios) of positions strictly below 1, 5 and 10;
        ``mean_pos`` is the arithmetic mean of ``positions``, or NaN when
        ``positions`` is empty.
    """
    if not positions:
        # np.mean([]) yields NaN but also emits a RuntimeWarning ("Mean of
        # empty slice"); return the same NaN without the noise. Nothing is
        # logged for empty input, exactly as before.
        return 0, 0, 0, np.float64('nan')

    top_1_acc = sum(1 for position in positions if position < 1)
    top_5_acc = sum(1 for position in positions if position < 5)
    top_10_acc = sum(1 for position in positions if position < 10)
    mean_pos = np.mean(positions)

    # The log line reports ratios, while the return value carries raw counts.
    total = len(positions)
    log.info("result: top1 %.2f top5 %.2f top10 %.2f mean_pos %.2f",
             top_1_acc / total, top_5_acc / total, top_10_acc / total, mean_pos)

    return top_1_acc, top_5_acc, top_10_acc, mean_pos
def time_elapsed(func, *args):
    """Call ``func(*args)`` and measure how long the call takes.

    Returns:
        A tuple ``(elapsed, result)`` where ``elapsed`` is the duration of the
        call as measured by ``time.perf_counter()`` and ``result`` is whatever
        ``func`` returned.
    """
    began = time.perf_counter()
    outcome = func(*args)
    return time.perf_counter() - began, outcome
def _mean_or_nan(samples):
    """Return the mean of *samples*, or NaN (without a NumPy warning) when empty."""
    return np.mean(samples) if samples else np.float64('nan')


def main():
    """Run the image retrieval demo.

    Parses the command line, builds the ImageRetrieval object, then iterates
    over the (image, view_frame) pairs produced by RoiDetectorOnVideo. For
    every pair that carries an image, an embedding is computed and searched
    in the gallery; every frame is visualized and, when an output name was
    given, written to a video file. After the loop the presenter report is
    printed and, if ground-truth positions were collected, the metrics are
    logged.

    Raises:
        RuntimeError: if the input is neither a video nor a camera, or if the
            output video writer cannot be opened.
    """
    log.basicConfig(format='[ %(levelname)s ] %(message)s', level=log.INFO, stream=sys.stdout)
    args = build_argparser().parse_args()

    img_retrieval = ImageRetrieval(args.model, args.device, args.gallery, INPUT_SIZE,
                                   args.cpu_extension)

    cap = open_images_capture(args.input, args.loop)
    if cap.get_type() not in ('VIDEO', 'CAMERA'):
        raise RuntimeError("The input should be a video file or a numeric camera ID")
    frames = RoiDetectorOnVideo(cap)

    compute_embeddings_times = []
    search_in_gallery_times = []
    positions = []

    frames_processed = 0
    presenter = monitors.Presenter(args.utilization_monitors, 0)
    video_writer = cv2.VideoWriter()

    try:
        for image, view_frame in frames:
            position = None
            sorted_indexes = []
            distances = None

            if image is not None:
                image = central_crop(image, divide_by=5, shift=1)

                elapsed, probe_embedding = time_elapsed(img_retrieval.compute_embedding, image)
                compute_embeddings_times.append(elapsed)

                elapsed, (sorted_indexes, distances) = time_elapsed(
                    img_retrieval.search_in_gallery, probe_embedding)
                search_in_gallery_times.append(elapsed)

                sorted_classes = [img_retrieval.gallery_classes[i] for i in sorted_indexes]

                if args.ground_truth is not None:
                    # Position of the ground-truth class within the ranked result list.
                    position = sorted_classes.index(
                        img_retrieval.text_label_to_class_id[args.ground_truth])
                    positions.append(position)
                    log.info("ROI detected, found: %d, position of target: %d",
                             sorted_classes[0], position)
                else:
                    log.info("ROI detected, found: %s", sorted_classes[0])

            # Distances are only forwarded when a ground-truth position was
            # found for this frame. The timing lists are empty until the first
            # image is processed; NaN is passed in that case, which is what
            # np.mean([]) produced, minus its RuntimeWarning.
            image, key = visualize(view_frame, position,
                                   [img_retrieval.impaths[i] for i in sorted_indexes],
                                   distances[sorted_indexes] if position is not None else None,
                                   img_retrieval.input_size,
                                   _mean_or_nan(compute_embeddings_times),
                                   _mean_or_nan(search_in_gallery_times),
                                   imshow_delay=3, presenter=presenter, no_show=args.no_show)

            # The writer is opened lazily on the first frame because the output
            # frame size is taken from the visualized image.
            if frames_processed == 0:
                if args.output and not video_writer.open(args.output,
                                                         cv2.VideoWriter_fourcc(*'MJPG'),
                                                         cap.fps(),
                                                         (image.shape[1], image.shape[0])):
                    raise RuntimeError("Can't open video writer")
            frames_processed += 1
            # A non-positive output limit means "store all frames".
            if video_writer.isOpened() and (args.output_limit <= 0
                                            or frames_processed <= args.output_limit):
                video_writer.write(image)

            # 27 is the ESC key code; presumably `key` is the key pressed in
            # the display window — confirm against visualize().
            if key == 27:
                break
    finally:
        # Finalize the output file deterministically, also on errors.
        video_writer.release()

    print(presenter.reportMeans())

    if positions:
        compute_metrics(positions)
if __name__ == '__main__':
    # main() has no return statement, so fall back to exit status 0.
    exit_status = main() or 0
    sys.exit(exit_status)
命令(注意:上面的demo代码只接受视频文件或摄像头作为-i的输入,若要传入单张图片,需要先修改代码中对输入类型的检查):
# CPU
python3 image_retrieval_demo.py -m ./models/resnet50.xml -i ./image_test/01.jpg -g ./image_test/img_list.txt -d CPU
# GPU
python3 image_retrieval_demo.py -m ./models/resnet50.xml -i ./image_test/01.jpg -g ./image_test/img_list.txt -d GPU
# 神经计算棒
python3 image_retrieval_demo.py -m ./models/resnet50.xml -i ./image_test/01.jpg -g ./image_test/img_list.txt -d MYRIAD
img_list.txt格式如下(每行为:图片路径 类别标签,以空格分隔):
01.jpg 0
02.jpg 0
03.jpg 1
标签:openvino,resnet50,help,--,onnx,image,args,add,retrieval 来源: https://blog.csdn.net/qq_35975447/article/details/118366778