编程语言
首页 > 编程语言> > Faster RCNN 推理 从头写 java (五) Classifier网络输出对 ROIs过滤与修正

Faster RCNN 推理 从头写 java (五) Classifier网络输出对 ROIs过滤与修正

作者:互联网

一: 输入输出

输入:

输出:

二: 流程

三: code by code

将classifier 网络的输出P_cls, P_regr 转换为Nd4j INDArray

INDArray P_cls = TypeConvertor.tensorToNDArray_s3(classifier_output.getCls());
INDArray P_regr = TypeConvertor.tensorToNDArray_s3(classifier_output.getRegression());

遍历每个ROIs,并切过滤掉背景和概率小于0.8 的ROI

for (int ii = 0; ii < P_cls.shape()[1]; ii++)
{
    INDArray theCls = P_cls.get(NDArrayIndex.point(0), NDArrayIndex.point(ii), NDArrayIndex.all());

    // 是否背景.
    boolean isBackground = theCls.argMax(0).getInt(0) == (P_cls.shape()[2] - 1);

    // 小于0.8概率 || 是背景
    if (theCls.maxNumber().floatValue() < bbox_threshold || isBackground)
    {
        continue;
    }

提取对应的ROI坐标和回归值

int x = ROIs.getInt(0, ii, 0);
int y = ROIs.getInt(0, ii, 1);
int w = ROIs.getInt(0, ii, 2);
int h = ROIs.getInt(0, ii, 3);

float tx = P_regr.getFloat(new int[]{0, ii, 0});
float ty = P_regr.getFloat(new int[]{0, ii, 1});
float tw = P_regr.getFloat(new int[]{0, ii, 2});
float th = P_regr.getFloat(new int[]{0, ii, 3});

使用Classifier的输出P_regr 来修正ROI, 这个算法逻辑与训练时生成Classifier 网络标注算法相反.

tx /= classifier_regr_std[0];
ty /= classifier_regr_std[1];
tw /= classifier_regr_std[2];
th /= classifier_regr_std[3];

int[] coor_out;

try {
    // [x, y, w, h] 格式.
    // coor_out[0]: x
    // coor_out[1]: y
    // coor_out[2]: w
    // coor_out[3]: h
    coor_out = apply_regr(x, y, w, h, tx, ty, tw, th);
}
catch (Exception e)
{
    continue;
}

将修正后的ROI坐标转换到VGG16的feature map 维度上,rpn_stride = 16
并将坐标从 [x1, y1, x2, y2] 转为 [x, y, w, h]

float x1 = coor_out[0] * rpn_stride;
float y1 = coor_out[1] * rpn_stride;
float x2 = (coor_out[0] + coor_out[2]) * rpn_stride;
float y2 = (coor_out[1] + coor_out[3]) * rpn_stride;

float[] bbox = new float[] {
        coor_out[0] * rpn_stride,
        coor_out[1] * rpn_stride,
        (coor_out[0] + coor_out[2]) * rpn_stride,
        (coor_out[1] + coor_out[3]) * rpn_stride
};

纵向排列一下,将N个bboxes的数组构建成INDArray shape = [N, 4]
修正为NMS(非最大值抑制)的数据输入格式. 下一个流程就需要执行NMS了。

INDArray candidate_bboxes = Nd4j.vstack(bboxes);
INDArray candidate_probs = Nd4j.create(probs);

标签:coor,java,Faster,int,float,regr,ii,ROIs,out
来源: https://blog.csdn.net/tabsong_coke/article/details/94052695