CenterNet:将 PyTorch 模型转换为 LibTorch 模型并在 C++ 中使用
作者:互联网
使用 GitHub 上原版的 CenterNet 代码生成模型,这部分请参考我的另一篇博文:
https://blog.csdn.net/qq_31610789/article/details/99938631
转到 C++ 之后需要用到 LibTorch 库,按照官方教程编译即可。CMakeLists.txt 如下:
cmake_minimum_required(VERSION 3.13)
project(CenterNetCppPro)
set(CMAKE_CXX_STANDARD 14)

# Bundled third-party SDKs (OpenCV 3.2, CUDA, prebuilt libtorch) live under 3rdParty/.
set(THIRD_PARTY_DIR ${CMAKE_SOURCE_DIR}/3rdParty)

# libnvrtc is looked up instead of being linked through a hard-coded absolute path, so the
# project also builds where CUDA is installed elsewhere. /usr/local/cuda/lib64 is searched
# first to keep the previously used library as the default; -DNVRTC_LIBRARY=... overrides it.
find_library(NVRTC_LIBRARY NAMES nvrtc
        HINTS /usr/local/cuda/lib64 ${THIRD_PARTY_DIR}/cuda/lib)
if(NOT NVRTC_LIBRARY)
    message(FATAL_ERROR "libnvrtc not found; pass -DNVRTC_LIBRARY=/path/to/libnvrtc.so")
endif()

add_executable(CenterNetCppPro
        main.cpp
        plateDangerousDetection/plateDangerousDetection.cpp
        plateDangerousDetection/include/plateDangerousDetection.hpp)

# Include/link paths are scoped to the target rather than the whole directory
# (target_link_directories requires CMake >= 3.13, the declared minimum).
target_include_directories(CenterNetCppPro PRIVATE
        ${THIRD_PARTY_DIR}/opencv-3.2.0/include
        ${THIRD_PARTY_DIR}/cuda/include
        ${THIRD_PARTY_DIR}/pytorch_libs/include
        ${THIRD_PARTY_DIR}/pytorch_libs/include/torch/csrc/api/include)
target_link_directories(CenterNetCppPro PRIVATE
        ${THIRD_PARTY_DIR}/opencv-3.2.0/lib
        ${THIRD_PARTY_DIR}/cuda/lib
        ${THIRD_PARTY_DIR}/pytorch_libs/lib)
target_link_libraries(CenterNetCppPro
        opencv_core opencv_imgproc opencv_imgcodecs opencv_highgui
        caffe2 caffe2_gpu c10_cuda c10 torch cuda
        ${NVRTC_LIBRARY})
下面给出基本用法(主函数见后文)。
处理函数如下,其中包括保存热力图的代码,以及验证输出是否与 PyTorch 版本一致的代码:
/**
 * Runs CenterNet on one image and produces a debug visualisation of its heatmap.
 *
 * Prints intermediate values (input pixels, raw network output, head sizes) so they can be
 * compared with the PyTorch implementation. It then blends heatmap channels 0-2 over a copy of
 * the input, writes that overlay to a fixed sample path and shows it in a window, blocking in
 * cv::waitKey().
 *
 * @param img    Input image; must be non-empty CV_8UC3 (what cv::imread returns with default
 *               flags). It is not modified: the overlay is drawn on a private copy.
 * @param output Detection results. NOT filled yet: decoding of the network output
 *               (centerNetDecoder) has not been ported.
 * @return 0 on success, -1 if `img` is empty or not CV_8UC3.
 */
int PlateDangerousDetection::process(cv::Mat& img, std::vector<PlateDangerousOutput>& output) {
    if (img.empty() || img.type() != CV_8UC3) {
        std::cerr << "[ERROR] PlateDangerousDetection::process: expected a non-empty CV_8UC3 image" << std::endl;
        return -1;
    }
    // Inference only: keeps autograd from recording the forward pass. A guard created
    // elsewhere (e.g. during model loading) only lasts for that scope.
    torch::NoGradGuard no_grad;

    // Private copy, because the overlay at the end of this function is drawn in place.
    cv::Mat img1 = img.clone();

    // Debug: first pixel (B, G, R) of the input image.
    for (int c = 0; c < 3; ++c) {
        std::cout << std::to_string(img1.at<cv::Vec3b>(0, 0)[c]) << std::endl;
    }

    // cv::Size is (width, height). The tensor below is laid out as {1, H, W, 3}, so the resized
    // image needs imageHeight_ rows and imageWidth_ columns. Interpolation is the 6th argument;
    // the 4th/5th are the fx/fy scale factors.
    cv::Mat image;
    cv::resize(img1, image, cv::Size(imageWidth_, imageHeight_), 0, 0, cv::INTER_LINEAR);
    for (int c = 0; c < 3; ++c) {
        std::cout << std::to_string(image.at<cv::Vec3b>(0, 0)[c]) << std::endl;
    }

    // Scale to [0, 1]. Channels stay in BGR order and no mean/std normalisation is applied.
    // NOTE(review): confirm this matches the preprocessing used on the PyTorch side.
    cv::Mat float_image;
    image.convertTo(float_image, CV_32F, 1.0 / 255);
    for (int c = 0; c < 3; ++c) {
        std::cout << float_image.at<cv::Vec3f>(0, 0)[c] << std::endl;
    }

    // Wrap the continuous OpenCV buffer without copying ({1, H, W, 3}), then view it as NCHW.
    auto img_tensor = torch::CPU(torch::kFloat32).tensorFromBlob(float_image.data, {1, imageHeight_, imageWidth_, 3});
    img_tensor = img_tensor.permute({0, 3, 1, 2});
    auto img_var = torch::autograd::make_variable(img_tensor, false);  // no gradient needed
    std::vector<torch::jit::IValue> inputs;
    // .to(kCUDA) copies the data, so float_image does not need to outlive this call.
    inputs.emplace_back(img_var.to(at::kCUDA));
    torch::Tensor result = centerNet->forward(inputs).toTensor();

    std::cout << "result.sizes() = " << result.sizes() << std::endl;
    std::cout << "Forward over!!!" << std::endl;
    for (int64_t i = 0; i < result.size(1); ++i) {
        std::cout << "result:" + std::to_string(i) + " " << result[0][i][0][0] << std::endl;
    }

    // Channel layout of `result` (per the model's heads): 0-3 'hm' (raw logits, not passed
    // through sigmoid), 4-5 'wh', 6-7 'reg'. The author's model produced 128x128 maps.
    torch::Tensor hm = at::select(result, 1, 2);
    std::cout << "hm.sizes() = " << hm.sizes() << std::endl;
    std::vector<torch::Tensor> splitResult = torch::split_with_sizes(result, {4, 2, 2}, 1);
    for (const auto& tt : splitResult) {
        std::cout << "tt.sizes() = " << tt.sizes() << std::endl;
    }

    // Visualise heatmap channels 0-2 as a single 3-channel image; channel 3 is not shown.
    std::vector<torch::Tensor> splitHeatMap = torch::split_with_sizes(splitResult[0], {3, 1}, 1);
    // (1, 3, H, W) -> (H, W, 3). The sizes are taken from the tensor rather than hard-coding 128.
    torch::Tensor heatMapTensor012 = splitHeatMap[0][0].detach().permute({1, 2, 0});
    std::cout << "heatMapTensor012.sizes() = " << heatMapTensor012.sizes() << std::endl;
    const int heatRows = static_cast<int>(heatMapTensor012.size(0));
    const int heatCols = static_cast<int>(heatMapTensor012.size(1));
    // NOTE(review): (x + 5) * 255 maps every logit >= -4 to 255 and <= -5 to 0.
    // sigmoid(x) * 255 may have been intended; confirm before trusting the picture.
    torch::Tensor heatMapTensor012Img = heatMapTensor012.add(5).mul(255).clamp(0, 255).to(torch::kU8);
    for (int i = 0; i < 3; ++i) {
        std::cout << "heatMapTensor012Img:" + std::to_string(i) + " " << heatMapTensor012Img[i][0][0] << std::endl;
    }
    // permute() only yields a strided view, and elementwise ops may keep that layout. The
    // buffer must be made dense on the host before it is copied as interleaved HWC bytes.
    heatMapTensor012Img = heatMapTensor012Img.to(torch::kCPU).contiguous();
    cv::Mat resultImg(heatRows, heatCols, CV_8UC3);
    std::memcpy(resultImg.data, heatMapTensor012Img.data_ptr(), resultImg.total() * resultImg.elemSize());

    cv::Mat resizedResultImg;
    cv::resize(resultImg, resizedResultImg, img1.size());
    // Saturate a non-negative int to the uchar range.
    const auto toByte = [](int v) { return static_cast<uchar>(v > 255 ? 255 : v); };
    for (int row = 0; row < img1.rows; ++row) {
        uchar* imgPx = img1.ptr<uchar>(row);
        const uchar* heatPx = resizedResultImg.ptr<uchar>(row);
        for (int col = 0; col < img1.cols; ++col, imgPx += 3, heatPx += 3) {
            // Heat intensity: sum of the three heatmap channels at this pixel (0..765).
            const int heat = heatPx[0] + heatPx[1] + heatPx[2];
            // Dim the image to 30%, then add heat. B and G receive heat / 3.1 and R receives
            // the full sum. NOTE(review): presumably this tints hot regions red; confirm.
            for (int c = 0; c < 2; ++c) {
                const uchar dimmed = static_cast<uchar>(imgPx[c] * 0.3);
                imgPx[c] = toByte(static_cast<int>(dimmed + heat / 3.1));
            }
            const uchar dimmedRed = static_cast<uchar>(imgPx[2] * 0.3);
            imgPx[2] = toByte(dimmedRed + heat);
        }
    }
    cv::imwrite("/data_1/vir/weixianpinche/train_ws/CenterNet/c++/sample/centerNet.jpg", img1);
    cv::imshow("heapmap012", img1);
    cv::waitKey();

    // TODO: decode `result` into `output` once centerNetDecoder has been ported.
    return 0;
}
// Construction, destruction, singleton access and initialization.
// Constructing the object does no work; the model is loaded later by init().
PlateDangerousDetection::PlateDangerousDetection() = default;
// Nothing is released explicitly; each member is cleaned up by its own destructor.
PlateDangerousDetection::~PlateDangerousDetection() = default;
/// Returns the calling thread's PlateDangerousDetection instance.
///
/// The instance is `thread_local`: each thread gets its own object, created on that thread's
/// first call. State set by init() on one thread is therefore not visible on other threads.
PlateDangerousDetection &PlateDangerousDetection::ins() {
    thread_local static PlateDangerousDetection perThreadInstance;
    return perThreadInstance;
}
/**
 * Loads the TorchScript CenterNet model and moves it to the GPU.
 *
 * @param configPath Directory containing plateDangerousDetection/torch_model.pt.
 * @return 0 on success, -1 if the model cannot be loaded or moved to CUDA (reason on stderr).
 */
int PlateDangerousDetection::init(const std::string& configPath) {
    // Only active for the duration of this call; it does not affect later inference.
    torch::NoGradGuard no_grad;
    const std::string modelPath = configPath + "/plateDangerousDetection/torch_model.pt";
    try {
        // torch::jit::load reports failures by throwing c10::Error (a std::exception).
        centerNet = torch::jit::load(modelPath);
        // Checked explicitly before the first dereference. An assert would be compiled out
        // under NDEBUG.
        if (centerNet == nullptr) {
            std::cerr << "[ERROR] torch::jit::load returned null for " << modelPath << std::endl;
            return -1;
        }
        centerNet->to(at::kCUDA);
    } catch (const std::exception& e) {
        std::cerr << "[ERROR] failed to load model " << modelPath << ": " << e.what() << std::endl;
        return -1;
    }
    std::cout << "[INFO] init model done...\n";
    return 0;
}
主函数如下(部分非关键代码未贴出):
// Demo entry point: loads the model, reads one sample image and runs detection on it.
// Returns 0 on success and a non-zero code from the first step that fails.
int main() {
    std::cout << "Hello, World!" << std::endl;
    int flag = PlateDangerousDetection::ins().init("/data_1/vir/weixianpinche/train_ws/CenterNet/c++");
    if (flag != 0) {
        std::cout << "PlateDangerousDetection init failed" << std::endl;
        return flag;
    }
    const std::string imagePath = "/data_1/vir/weixianpinche/train_ws/CenterNet/c++/sample/130642000000025553+冀FK8196+134414+130642000000025553+01+0+@1873077@@@2019-12-18#23#29#50+a1+0_0.jpg";
    cv::Mat carRect = cv::imread(imagePath);
    // cv::imread signals failure by returning an empty Mat, not by throwing.
    if (carRect.empty()) {
        std::cerr << "failed to read image: " << imagePath << std::endl;
        return 1;
    }
    std::vector<PlateDangerousOutput> plateDangerousOutput;
    const int ret = PlateDangerousDetection::ins().process(carRect, plateDangerousOutput);
    if (ret != 0) {
        std::cerr << "PlateDangerousDetection process failed: " << ret << std::endl;
        return ret;
    }
    return 0;
}
后处理部分(实际上这是比较复杂的部分)尚未移植,有待后续更新;可以借助 C++ 版本的 NumPy 类库进行移植。
AaronJiang395 发布了38 篇原创文章 · 获赞 8 · 访问量 1万+ 私信 关注标签:std,temp,CenterNet,col,libtorch,pytorch,PlateDangerousDetection,include,data0 来源: https://blog.csdn.net/qq_31610789/article/details/104063943