使用摄像头(camera)在 TensorFlow/slim 下调用 pb 文件进行图像识别预测
作者:互联网
建立 demo_cam.py 文件,Python 代码如下:
代码中使用的摄像头是 Intel RealSense D435i
import tensorflow as tf
import numpy as np
import cv2
from datasets import dataset_utils
#from IPython import display
#import pylab
#import PIL
import time
from PIL import Image
from PIL import ImageDraw
from PIL import ImageFont
import matplotlib.font_manager as fm
import pyrealsense2 as rs
# Module-level RealSense setup. These statements run at import time and the
# resulting globals (pipeline, align, dataset_dir) are read by the code below.
pipeline = rs.pipeline()
config = rs.config()
# Request a z16 depth stream and a bgr8 colour stream with the same
# 640, 480, 30 arguments (presumably width, height, fps -- confirm against
# the pyrealsense2 enable_stream signature).
config.enable_stream(rs.stream.depth, 640, 480, rs.format.z16, 30)
config.enable_stream(rs.stream.color, 640, 480, rs.format.bgr8, 30)
# Start the pipeline with the configuration above; the returned value is
# kept in `profile` but is not used by the code visible here.
profile = pipeline.start(config)
# Aligner targeting the colour stream; used by get_aligned_images().
align_to = rs.stream.color
align = rs.align(align_to)
#image_dir='./data/flower_photos/daisy/5547758_eea9edfd54_n.jpg'
# Directory handed to dataset_utils.read_label_file() by TOD.__init__.
dataset_dir='./data/flower_photos'
# Path of the frozen graph. NOTE(review): not referenced by the visible code;
# TOD.__init__ hard-codes the same path string.
model_dir ='./output_model_pb/frozen_graph.pb'
def get_aligned_images():
    """Grab one colour/depth frame pair aligned to the colour stream.

    Blocks on the module-level ``pipeline`` until a frameset arrives that
    contains both a depth frame and a colour frame, so callers never receive
    a partially populated result.

    Returns:
        tuple: ``(color_image, depth_median_blur, depth_image_3d_org)``

        * ``color_image`` -- the colour frame as a NumPy array.
        * ``depth_median_blur`` -- the 8-bit depth image, inverted
          (``255 - value``), with pixels that were 0 before inversion reset
          to 0, then smoothed with a 5x5 median filter.
        * ``depth_image_3d_org`` -- the non-inverted 8-bit depth image
          stacked three times along the last axis.
    """
    while True:
        frames = pipeline.wait_for_frames()
        aligned_frames = align.process(frames)
        aligned_depth_frame = aligned_frames.get_depth_frame()
        color_frame = aligned_frames.get_color_frame()
        # A frameset may lack one of the streams; the frame object is then
        # falsy and get_data() must not be called on it. Wait for the next one.
        if aligned_depth_frame and color_frame:
            break

    depth_image = np.asanyarray(aligned_depth_frame.get_data())
    color_image = np.asanyarray(color_frame.get_data())

    # Scale the raw depth by 0.03 and saturate it into an 8-bit image.
    depth_org = cv2.convertScaleAbs(depth_image, alpha=0.03)

    # Invert, then zero the pixels that were 0 before inversion (they became
    # 255) so they do not show up as the brightest value.
    depth_inverted = 255 - depth_org
    depth_inverted[depth_inverted == 255] = 0

    depth_median_blur = cv2.medianBlur(depth_inverted, 5)  # median filter

    # Depth is single-channel; replicate it to three channels.
    depth_image_3d_org = np.dstack((depth_org, depth_org, depth_org))
    return color_image, depth_median_blur, depth_image_3d_org
#opencv
class TOD(object):
    """Image classifier backed by a frozen TensorFlow graph (.pb file).

    The constructor loads the graph, opens a session on it and reads the
    label map. ``classify`` runs one image through the graph, draws the top
    label and its score on the frame and shows it in an OpenCV window.
    """

    def __init__(self, model_path='./output_model_pb/frozen_graph.pb',
                 labels_dir=None):
        """Load the frozen graph, the label map and the drawing font.

        Args:
            model_path: Path of the frozen graph file. Defaults to the path
                the original script hard-coded.
            labels_dir: Directory passed to
                ``dataset_utils.read_label_file``. When ``None`` the
                module-level ``dataset_dir`` is used.
        """
        self.PATH_TO_CKPT = model_path
        self.NUM_CLASSES = 5
        self.detection_graph = tf.Graph()
        self.label_map = dataset_utils.read_label_file(
            dataset_dir if labels_dir is None else labels_dir)

        with self.detection_graph.as_default():
            od_graph_def = tf.GraphDef()
            with tf.gfile.GFile(self.PATH_TO_CKPT, 'rb') as fid:
                od_graph_def.ParseFromString(fid.read())
            tf.import_graph_def(od_graph_def, name='')

            config = tf.ConfigProto()
            # Allocate GPU memory on demand instead of reserving it all.
            config.gpu_options.allow_growth = True
            self.sess = tf.Session(graph=self.detection_graph, config=config)

        # Resolve the tensors once instead of on every frame.
        self._input_tensor = self.detection_graph.get_tensor_by_name('input:0')
        self._predictions_tensor = self.detection_graph.get_tensor_by_name(
            'InceptionResnetV2/Logits/Predictions:0')

        # Locating and loading the font is slow; do it once, not per frame.
        self._font = ImageFont.truetype(
            fm.findfont(fm.FontProperties(family='DejaVu Sans')), 15)

        # True until the display window has been created by classify().
        self.windowNotSet = True

    def visualization(self, image, str):  # `str` kept: it is part of the public signature
        """Draw ``str`` near the top-left corner of ``image``, in place.

        Args:
            image: Three-channel image array in OpenCV's BGR channel order.
                It is modified in place.
            str: Text to draw.

        Returns:
            The same ``image`` array, with the text drawn on it.
        """
        image_pil = Image.fromarray(np.uint8(image)).convert('RGB')
        draw = ImageDraw.Draw(image_pil)
        # The buffer holds BGR data that PIL merely labels as "RGB", so the
        # fill colour is given in BGR order: (0, 0, 255) is displayed as red
        # by cv2.imshow. The name 'red' would be displayed as blue.
        draw.text((10, 10), str, fill=(0, 0, 255), font=self._font)
        np.copyto(image, np.array(image_pil))
        return image

    def _format_prediction(self, pred):
        """Return ``"<label>:<score>"`` for the highest-scoring class in ``pred``."""
        label = str(self.label_map[int(pred.argmax())])
        return label + ":" + str(pred.max())

    def classify(self, image, resized):
        """Classify ``resized`` and show ``image`` annotated with the result.

        Args:
            image: Frame to annotate and display; modified in place.
            resized: Preprocessed model input without the batch dimension.
        """
        # Add the batch dimension; the model expects [1, None, None, 3].
        image_np_expanded = np.expand_dims(resized, axis=0)

        start_time = time.time()
        pred = self.sess.run(
            self._predictions_tensor,
            feed_dict={self._input_tensor: image_np_expanded})
        elapsed_time = time.time() - start_time
        print('inference time cost: {}'.format(elapsed_time))

        img = self.visualization(image, self._format_prediction(pred))

        # Create the window only once; later frames just update its content.
        if self.windowNotSet:
            cv2.namedWindow("classification", cv2.WINDOW_NORMAL)
            self.windowNotSet = False
        cv2.imshow("classification", img)
if __name__ == '__main__':
    # Spatial size of the model input.
    width = 299
    height = 299
    dim = (width, height)

    classifier = None
    try:
        classifier = TOD()
        while True:
            rgb, depth, depcol = get_aligned_images()
            # Despite the name, the colour stream is configured as bgr8, so
            # this frame is in BGR order. It is used as-is for display.
            image = rgb

            # The model input must be RGB: convert the resized copy only.
            model_input = cv2.cvtColor(cv2.resize(image, dim), cv2.COLOR_BGR2RGB)
            # Map pixel values from [0, 255] to [-1, 1]. np.float32 replaces
            # the np.float alias, which was removed in NumPy 1.24.
            resized = model_input.astype(np.float32) / 127.5 - 1

            classifier.classify(image, resized)

            # Quit on 'q' or Esc.
            k = cv2.waitKey(1) & 0xff
            if k == ord('q') or k == 27:
                break
    finally:
        # Release the camera, the session and the windows even if the loop
        # (or the classifier construction) raised.
        pipeline.stop()
        if classifier is not None:
            classifier.sess.close()
        cv2.destroyAllWindows()
其中用到的 labels.txt 文件(需放在 dataset_dir 指定的目录下,即 ./data/flower_photos/labels.txt)的格式为:
0:daisy
1:dandelion
2:roses
3:sunflowers
4:tulips
运行
python demo_cam.py
标签:图像识别,graph,image,slim,pb,depth,np,import,self 来源: https://blog.csdn.net/gaoqing_dream163/article/details/115205714