其他分享
首页 > 其他分享> > ECM技术学习:模板匹配(Template matching)

ECM技术学习:模板匹配(Template matching)

作者:互联网

模板匹配(Template matching, TM)是一种解码端推导方法,用来细化当前CU的运动信息,使得当前CU的MV更准确。

TM 的目标是寻找一个 MV,使当前图片的模板(当前 CU 的上方和/或左侧相邻块)与参考图片中对应模板之间的匹配误差最小。如下图所示,TM 在 [–8, +8] 像素搜索范围内围绕当前 CU 的初始 MV 搜索更优的 MV。TM 根据 AMVR 模式确定搜索步长,并且在 Merge 模式下 TM 可以与双边匹配(bilateral matching, BM)过程级联。

在 AMVP 模式下,TM 仅细化一个特定的 MVP 候选:根据模板匹配误差,选取当前块模板与参考块模板差异最小的 MVP 候选进行 TM 细化。TM 使用迭代菱形搜索细化该 MVP 候选,搜索范围为 [–8, +8] 像素,起始精度为全像素 MVD 精度(4 像素 AMVR 模式下为 4 像素精度)。随后使用同样精度(全像素,或 4 像素 AMVR 模式下的 4 像素)的十字搜索进一步细化,再根据表 1 中指定的 AMVR 模式依次进行半像素和四分之一像素搜索。该搜索过程保证 MVP 候选在 TM 细化之后仍保持与 AMVR 模式所指示的 MV 精度相同。

在 Merge 模式下,对 Merge 索引所指示的 Merge 候选应用类似的搜索方法。如表 1 所示,TM 可以一直细化到 1/8 像素 MVD 精度,也可以跳过精度高于半像素 MVD 精度的搜索步骤,这取决于根据 Merge 运动信息是否使用了替代插值滤波器(alternative interpolation filter, AltIF,在 AMVR 为半像素模式时使用)。此外,当启用 TM 模式时,模板匹配既可以作为独立过程,也可以作为基于块的双边匹配(BM)与基于子块的双边匹配之间的额外 MV 细化过程,具体取决于 BM 能否通过其启用条件检查而被启用。

相关代码

ECM中,TM细化MV的入口函数是deriveTMMv函数,需要注意的是Merge模式下对全部的MV候选项都会进行TM细化,而AMVP模式下仅对候选列表中的模板匹配误差最小的MV进行TM细化,二者调用的函数不同,如下所示:

#if TM_MRG
  // Merge-mode entry point: refines the MV(s) of a TM-merge PU in place.
  void       deriveTMMv         (PredictionUnit& pu);
#endif
  // Refines one specific MV (list eRefList, reference refIdx) with template matching and
  // returns the resulting matching cost; the refined MV is written back through `mv`.
  Distortion deriveTMMv         (const PredictionUnit& pu, bool fillCurTpl, Distortion curBestCost, RefPicList eRefList, int refIdx, int maxSearchRounds, Mv& mv, const MvField* otherMvf = nullptr);

 这两个函数的代码及注释如下所示:

// Merge-mode template matching: refines the motion of a TM-merge PU in place.
//
// Pass 1 refines every active reference list on its own (uni-directional TM).
// Pass 2 runs only for a bi-predicted PU in a B slice (and, with MULTI_PASS_DMVR, only when
// BDMVR does not apply). It keeps the list with the lower uni cost fixed and re-searches the
// other list against it. If the resulting bi cost exceeds the fixed list's uni cost by more
// than about 1/8, the PU is converted to uni-prediction on the fixed list.
void InterPrediction::deriveTMMv(PredictionUnit& pu)
{
  if (!pu.tmMergeFlag)
  {
    return; // TM merge is not used for this PU
  }

  const Distortion unsetCost = std::numeric_limits<Distortion>::max();
  const bool       isBSlice  = pu.cu->slice->isInterB();
  const int        numLists  = isBSlice ? NUM_REF_PIC_LIST_01 : 1;
  Distortion       uniCost[NUM_REF_PIC_LIST_01] = { unsetCost, unsetCost };

  // Pass 1: independent refinement of every list used by the PU.
  for (int list = 0; list < numLists; ++list)
  {
    const bool listUsed = (pu.interDir & (list + 1)) != 0;
    if (!listUsed)
    {
      continue;
    }
    uniCost[list] = deriveTMMv(pu, true, unsetCost, (RefPicList)list, pu.refIdx[list], TM_MAX_NUM_OF_ITERATIONS, pu.mv[list]);
  }

  // Pass 2 is restricted to bi-predicted PUs (and, with MULTI_PASS_DMVR, to PUs without BDMVR).
  bool refineBi = isBSlice && pu.interDir == 3;
#if MULTI_PASS_DMVR
  refineBi = refineBi && !PU::checkBDMVRCondition(pu);
#endif
  if (!refineBi)
  {
    return;
  }

  // Both lists must have produced a valid template cost in pass 1.
  if (uniCost[0] == unsetCost || uniCost[1] == unsetCost)
  {
    return;
  }

  // Fix the list with the lower uni cost (L0 on ties) and re-search the other one against it.
  const RefPicList targetList = (uniCost[0] <= uniCost[1]) ? REF_PIC_LIST_1 : REF_PIC_LIST_0;
  const int        anchorList = 1 - targetList;
  const MvField    anchorMvf(pu.mv[anchorList], pu.refIdx[anchorList]);

  const Distortion biCost     = deriveTMMv(pu, true, unsetCost, targetList, pu.refIdx[targetList], TM_MAX_NUM_OF_ITERATIONS, pu.mv[targetList], &anchorMvf);
  const Distortion anchorCost = uniCost[anchorList];

  // Bi-prediction survives unless it is worse than the anchor's uni cost plus 1/8 of it.
  const bool fallBackToUni = biCost > anchorCost + (anchorCost >> 3);
  if (fallBackToUni)
  {
    pu.interDir           = 1 + anchorList;
    pu.mv    [targetList] = Mv();
    pu.refIdx[targetList] = NOT_VALID;
  }
}

 

#if TM_AMVP || TM_MRG
// Template-matching refinement of a single MV.
//   pu              : prediction unit whose templates are matched.
//   fillCurTpl      : if true, the current-picture templates are read into m_pcCurTplAbove/Left;
//                     if false, those buffers are used as they are (presumably filled by an
//                     earlier call for the same PU).
//   curBestCost     : starting cost; std::numeric_limits<Distortion>::max() makes the matcher
//                     evaluate the cost of the initial MV itself.
//   eRefList/refIdx : reference picture to search in.
//   maxSearchRounds : maximum number of search rounds; 0 means no search, only the cost of the
//                     initial MV's template is computed.
//   mv              : in: start MV; out: refined MV.
//   otherMvf        : nullptr for uni-prediction. Otherwise the other list's motion, whose
//                     prediction is removed from the current template before matching (bi case).
// Returns the matching cost (BCW-weighted in the bi case), or
// std::numeric_limits<Distortion>::max() when neither the above nor the left template exists.
Distortion InterPrediction::deriveTMMv(const PredictionUnit& pu, bool fillCurTpl, Distortion curBestCost, RefPicList eRefList, int refIdx, int maxSearchRounds, Mv& mv, const MvField* otherMvf)
{
  CHECK(refIdx < 0, "Invalid reference index for TM");
  const CodingUnit& cu   = *pu.cu;
  const Picture& refPic  = *cu.slice->getRefPic(eRefList, refIdx)->unscaledPic;
  // The other list's MV is handed to the matcher only when both lists reference the same POC.
  bool doSimilarityCheck = otherMvf == nullptr ? false : cu.slice->getRefPOC((RefPicList)eRefList, refIdx) == cu.slice->getRefPOC((RefPicList)(1 - eRefList), otherMvf->refIdx);

  InterPredResources interRes(m_pcReshape, m_pcRdCost, m_if, m_filteredBlockTmp[0][COMPONENT_Y]
                           ,  m_filteredBlock[3][1][0], m_filteredBlock[3][0][0]
  );
  // The constructor sets up the current and reference templates.
  TplMatchingCtrl tplCtrl(pu, interRes, refPic, fillCurTpl, COMPONENT_Y, true, maxSearchRounds, m_pcCurTplAbove, m_pcCurTplLeft, m_pcRefTplAbove, m_pcRefTplLeft, mv, (doSimilarityCheck ? &(otherMvf->mv) : nullptr), curBestCost);

  if (!tplCtrl.getTemplatePresentFlag())
  {
    // Neither the above nor the left template is available.
    return std::numeric_limits<Distortion>::max();
  }

  if (otherMvf == nullptr) // uni-prediction
  {
    tplCtrl.deriveMvUni<TM_TPL_SIZE>();
    mv = tplCtrl.getFinalMv();   // refined MV
    return tplCtrl.getMinCost(); // its matching cost
  }

  // Bi-prediction: remove the other list's (weighted) prediction from the current template,
  // then refine this list's MV against the residual template.
  const Picture& otherRefPic = *cu.slice->getRefPic((RefPicList)(1 - eRefList), otherMvf->refIdx)->unscaledPic;
  const int8_t   intWeight   = getBcwWeight(cu.BcwIdx, eRefList);
  tplCtrl.removeHighFreq<TM_TPL_SIZE>(otherRefPic, otherMvf->mv, intWeight);
  tplCtrl.deriveMvUni<TM_TPL_SIZE>();
  mv = tplCtrl.getFinalMv();

  // Scale the cost by this list's BCW weight, with rounding. The rounding offset
  // (g_BcwWeightBase >> 1) shows the weight is expressed in units of g_BcwWeightBase, so the
  // normalization is a division by g_BcwWeightBase (a shift by its log2). The previous code
  // shifted by g_BcwWeightBase itself, shrinking the cost far too much and biasing the
  // bi-vs-uni comparison done by the caller. Distortion is assumed unsigned, so the
  // division is exact integer truncation, the same as the log2 shift.
  return (tplCtrl.getMinCost() * intWeight + (g_BcwWeightBase >> 1)) / g_BcwWeightBase;
}

TM过程中模板的获取以及搜索过程都是通过TplMatchingCtrl类控制的,代码如下所示:

// Controls template matching for one MV: builds the current-picture templates (above/left of
// the PU) and the reference templates, and runs the MV search.
class TplMatchingCtrl
{
  // Search patterns accepted by xRefineMvSearch.
  enum TMSearchMethod
  {
    TMSEARCH_DIAMOND,
    TMSEARCH_CROSS,
    TMSEARCH_NUMBER_OF_METHODS
  };

  const CodingUnit&         m_cu;
  const PredictionUnit&     m_pu;
        InterPredResources& m_interRes;

  const Picture&    m_refPic;          // reference picture being searched
  const Mv          m_mvStart;         // initial (unrefined) MV
        Mv          m_mvFinal;         // refined MV (starts equal to m_mvStart)
  const Mv*         m_otherRefListMv;  // other list's MV; may be nullptr
        Distortion  m_minCost;         // best matching cost found so far
        bool        m_useWeight;
        int         m_maxSearchRounds; // <= 0: only the start MV is evaluated, no search
        ComponentID m_compID;

  PelBuf m_curTplAbove; // current-picture template above the PU (empty if unavailable)
  PelBuf m_curTplLeft;  // current-picture template left of the PU (empty if unavailable)
  PelBuf m_refTplAbove; // reference-template buffer, same size as m_curTplAbove
  PelBuf m_refTplLeft;  // reference-template buffer, same size as m_curTplLeft
  PelBuf m_refSrAbove; // pre-filled samples on search area
  PelBuf m_refSrLeft;  // pre-filled samples on search area

#if JVET_X0056_DMVD_EARLY_TERMINATION
  Distortion m_earlyTerminateTh; // derived from the number of available template samples
#endif
#if MULTI_PASS_DMVR
  // Presumably the costs at the diamond (9) and cross (5) search points -- TODO confirm.
  Distortion m_tmCostArrayDiamond[9];
  Distortion m_tmCostArrayCross[5];
#endif

public:
  // Fills/locates the current templates and prepares the reference templates.
  TplMatchingCtrl(const PredictionUnit&     pu,
                        InterPredResources& interRes, // Bridge required resource from InterPrediction
                  const Picture&            refPic,
                  const bool                fillCurTpl,
                  const ComponentID         compID,
                  const bool                useWeight,
                  const int                 maxSearchRounds,
                        Pel*                curTplAbove,
                        Pel*                curTplLeft,
                        Pel*                refTplAbove,
                        Pel*                refTplLeft,
                  const Mv&                 mvStart,
                  const Mv*                 otherRefListMv,
                  const Distortion          curBestCost
  );
  // True if at least one of the above/left templates is available.
  bool       getTemplatePresentFlag() const { return m_curTplAbove.buf != nullptr || m_curTplLeft.buf != nullptr; }
  Distortion getMinCost            () const { return m_minCost; } // best matching cost
  Mv         getFinalMv            () const { return m_mvFinal; } // refined MV
  static int getDeltaMean          (const PelBuf& bufCur, const PelBuf& bufRef, const int rowSubShift, const int bd);

  template <int tplSize> void deriveMvUni    (); // refine the MV for a single list
  template <int tplSize> void removeHighFreq (const Picture& otherRefPic, const Mv& otherRefMv, const uint8_t curRefBcwWeight);

private:
  template <int tplSize, bool TrueA_FalseL>         bool       xFillCurTemplate   (Pel* tpl);
  template <int tplSize, bool TrueA_FalseL, int sr> PelBuf     xGetRefTemplate    (const PredictionUnit& curPu, const Picture& refPic, const Mv& _mv, PelBuf& dstBuf);
  template <int tplSize, bool TrueA_FalseL>         void       xRemoveHighFreq    (const Picture& otherRefPic, const Mv& otherRefMv, const uint8_t curRefBcwWeight);
  template <int tplSize, int searchPattern>         void       xRefineMvSearch    (int maxSearchRounds, int searchStepShift);
#if MULTI_PASS_DMVR
  template <int searchPattern>                      void       xNextTmCostAarray  (int bestDirect);
  template <int searchPattern>                      void       xDeriveCostBasedMv ();
  template <bool TrueX_FalseY>                      void       xDeriveCostBasedOffset (Distortion costLorA, Distortion costCenter, Distortion costRorB, int log2StepSize);
                                                    int        xBinaryDivision    (int64_t numerator, int64_t denominator, int fracBits);
#endif
  template <int tplSize>                            Distortion xGetTempMatchError (const Mv& mv);
  template <int tplSize, bool TrueA_FalseL>         Distortion xGetTempMatchError (const Mv& mv);
};

 TM模式中的当前模板的获取和参考模板的获取是在TplMatchingCtrl类的构造函数中实现的,分别调用xFillCurTemplate函数和xGetRefTemplate函数实现当前模板的获取和参考模板的获取。

#if TM_AMVP || TM_MRG
TplMatchingCtrl::TplMatchingCtrl( const PredictionUnit&     pu,
                                        InterPredResources& interRes,
                                  const Picture&            refPic,
                                  const bool                fillCurTpl,
                                  const ComponentID         compID,
                                  const bool                useWeight,
                                  const int                 maxSearchRounds,
                                        Pel*                curTplAbove,
                                        Pel*                curTplLeft,
                                        Pel*                refTplAbove,
                                        Pel*                refTplLeft,
                                  const Mv&                 mvStart,
                                  const Mv*                 otherRefListMv,
                                  const Distortion          curBestCost
)
: m_cu              (*pu.cu)
, m_pu              (pu)
, m_interRes        (interRes)
, m_refPic          (refPic)
, m_mvStart         (mvStart)
, m_mvFinal         (mvStart)
, m_otherRefListMv  (otherRefListMv)
, m_minCost         (curBestCost)
, m_useWeight       (useWeight)
, m_maxSearchRounds (maxSearchRounds)
, m_compID          (compID)
{
  // Determine template availability. When fillCurTpl is set, the current-picture samples are
  // also copied into curTplAbove/curTplLeft; otherwise those buffers are used as they are.
  Pel* const aboveFillDst = fillCurTpl ? curTplAbove : nullptr;
  Pel* const leftFillDst  = fillCurTpl ? curTplLeft  : nullptr;
  const bool hasAbove     = xFillCurTemplate<TM_TPL_SIZE, true >(aboveFillDst);
  const bool hasLeft      = xFillCurTemplate<TM_TPL_SIZE, false>(leftFillDst);

  // Current and reference template views; an unavailable side gets an empty buffer.
  if (hasAbove)
  {
    m_curTplAbove = PelBuf(curTplAbove, pu.lwidth(), TM_TPL_SIZE);
    m_refTplAbove = PelBuf(refTplAbove, m_curTplAbove);
  }
  else
  {
    m_curTplAbove = PelBuf();
    m_refTplAbove = PelBuf();
  }

  if (hasLeft)
  {
    m_curTplLeft = PelBuf(curTplLeft, TM_TPL_SIZE, pu.lheight());
    m_refTplLeft = PelBuf(refTplLeft, m_curTplLeft);
  }
  else
  {
    m_curTplLeft = PelBuf();
    m_refTplLeft = PelBuf();
  }

#if JVET_X0056_DMVD_EARLY_TERMINATION
  m_earlyTerminateTh = TM_TPL_SIZE * ((hasAbove ? m_pu.lwidth() : 0) + (hasLeft ? m_pu.lheight() : 0));
#endif

  // When a search will run, interpolate the reference samples of the whole search window
  // (template plus TM_SEARCH_RANGE samples on every side) once, up front.
  const bool willSearch = maxSearchRounds > 0;

  if (hasAbove && willSearch)
  {
    m_refSrAbove = PelBuf(interRes.m_preFillBufA, m_curTplAbove.width + 2 * TM_SEARCH_RANGE, m_curTplAbove.height + 2 * TM_SEARCH_RANGE);
  }
  else
  {
    m_refSrAbove = PelBuf();
  }
  if (m_refSrAbove.buf != nullptr)
  {
    m_refSrAbove = xGetRefTemplate<TM_TPL_SIZE, true, TM_SEARCH_RANGE>(m_pu, m_refPic, mvStart, m_refSrAbove);
    // Re-anchor the view on the sample that corresponds to the start MV.
    m_refSrAbove = m_refSrAbove.subBuf(Position(TM_SEARCH_RANGE, TM_SEARCH_RANGE), m_curTplAbove);
  }

  if (hasLeft && willSearch)
  {
    m_refSrLeft = PelBuf(interRes.m_preFillBufL, m_curTplLeft .width + 2 * TM_SEARCH_RANGE, m_curTplLeft .height + 2 * TM_SEARCH_RANGE);
  }
  else
  {
    m_refSrLeft = PelBuf();
  }
  if (m_refSrLeft.buf != nullptr)
  {
    m_refSrLeft = xGetRefTemplate<TM_TPL_SIZE, false, TM_SEARCH_RANGE>(m_pu, m_refPic, mvStart, m_refSrLeft);
    // Re-anchor the view on the sample that corresponds to the start MV.
    m_refSrLeft = m_refSrLeft.subBuf(Position(TM_SEARCH_RANGE, TM_SEARCH_RANGE), m_curTplLeft);
  }
}

xFillCurTemplate函数获取当前模板: 

// Checks whether the above (TrueA_FalseL == true) or left (false) template of the current PU is
// available and, if tpl is non-null, copies the reconstructed samples into tpl row by row
// (lwidth x tplSize for above, tplSize x lheight for left). For luma with LMCS active, samples
// are mapped through the reshaper's inverse LUT.
// Returns false if the neighbouring CU does not exist or, for GPM with TM, if the GPM template
// type excludes this side; true otherwise.
template <int tplSize, bool TrueA_FalseL>
bool TplMatchingCtrl::xFillCurTemplate(Pel* tpl)
{
  // Offset from the PU's top-left sample to the first template row (above) or column (left).
  const Position          posOffset = TrueA_FalseL ? Position(0, -tplSize) : Position(-tplSize, 0);
  const CodingUnit* const cuNeigh   = m_cu.cs->getCU(m_pu.blocks[m_compID].pos().offset(posOffset), toChannelType(m_compID));

  if (cuNeigh == nullptr) // neighbouring CU not available
  {
    return false;
  }

#if JVET_W0097_GPM_MMVD_TM & TM_MRG
  // The GPM template type restricts which side may be used. This is an availability decision,
  // so it must run before the "availability query only" early return below. Otherwise a call
  // with tpl == nullptr reports an excluded side as available, and the caller then matches on it.
  if (m_cu.geoFlag)
  {
    CHECK(m_pu.geoTmType == GEO_TM_OFF, "invalid geo template type value");
    const bool excluded = (m_pu.geoTmType == GEO_TM_SHAPE_A && !TrueA_FalseL)  // above-only: no left
                       || (m_pu.geoTmType == GEO_TM_SHAPE_L &&  TrueA_FalseL); // left-only: no above
    if (excluded)
    {
      return false;
    }
  }
#endif

  if (tpl == nullptr) // availability query only; nothing to fill
  {
    return true;
  }

  const Picture&          currPic = *m_cu.cs->picture;
  const CPelBuf           recBuf  = currPic.getRecoBuf(m_cu.cs->picture->blocks[m_compID]); // current reconstruction
        std::vector<Pel>& invLUT  = m_interRes.m_pcReshape->getInvLUT();
  const bool              useLUT  = isLuma(m_compID) && m_cu.cs->picHeader->getLmcsEnabledFlag() && m_interRes.m_pcReshape->getCTUFlag();

  const Size dstSize = (TrueA_FalseL ? Size(m_pu.lwidth(), tplSize) : Size(tplSize, m_pu.lheight()));
  for (int h = 0; h < (int)dstSize.height; h++)
  {
    const Position recPos = TrueA_FalseL ? Position(0, -tplSize + h) : Position(-tplSize, h);
    const Pel*     rec    = recBuf.bufAt(m_pu.blocks[m_compID].pos().offset(recPos));
          Pel*     dst    = tpl + h * dstSize.width;

    for (int w = 0; w < (int)dstSize.width; w++)
    {
      int recVal = rec[w];
      dst[w] = useLUT ? invLUT[recVal] : recVal;
    }
  }

  return true;
}

xGetRefTemplate函数获取参考模板: 

// Produces the reference template (above if TrueA_FalseL, else left) for MV _mv in refPic.
// Fast path: with sr == 0, the same reference picture as m_refPic, and an available pre-filled
// search window, an integer-sample MV offset from m_mvStart within TM_SEARCH_RANGE is served as a
// view into m_refSrAbove/m_refSrLeft (dstBuf is not written).
// Otherwise the samples are interpolated into dstBuf, which is returned. The MV is shifted by
// (-sr, -sr) full samples, so dstBuf's top-left lies sr samples above/left of the template;
// the caller sizes dstBuf to include that margin.
template <int tplSize, bool TrueA_FalseL, int sr>
PelBuf TplMatchingCtrl::xGetRefTemplate(const PredictionUnit& curPu, const Picture& refPic, const Mv& _mv, PelBuf& dstBuf)
{
  // Pre-interpolated search window for this side (empty buffer if it was not prepared).
  PelBuf& refSrBuf = TrueA_FalseL ? m_refSrAbove : m_refSrLeft;
  // Fast path: only when sr == 0, same reference POC, and the window exists.
  if (sr == 0 && refPic.getPOC() == m_refPic.getPOC() && refSrBuf.buf != nullptr)
  {
    Mv mvDiff = _mv - m_mvStart;
    // The offset from the start MV must be a whole number of samples in both directions...
    if ((mvDiff.getAbsHor() & ((1 << MV_FRACTIONAL_BITS_INTERNAL) - 1)) == 0 && (mvDiff.getAbsVer() & ((1 << MV_FRACTIONAL_BITS_INTERNAL) - 1)) == 0)
    {
      mvDiff >>= MV_FRACTIONAL_BITS_INTERNAL; // convert to full-sample units
      // ...and must lie inside the pre-filled window.
      if (mvDiff.getAbsHor() <= TM_SEARCH_RANGE && mvDiff.getAbsVer() <= TM_SEARCH_RANGE)
      {
        return refSrBuf.subBuf(Position(mvDiff.getHor(), mvDiff.getVer()), dstBuf);
      }
    }
  }

  // Interpolate on the fly.
  // Top-left of the template: tplSize rows above or tplSize columns left of the PU.
  Position blkPos  = ( TrueA_FalseL ? Position(curPu.lx(), curPu.ly() - tplSize) : Position(curPu.lx() - tplSize, curPu.ly()) );
  Size     blkSize = Size(dstBuf.width, dstBuf.height);
  // Move the MV back by sr full samples so the output also covers the search margin.
  Mv       mv      = _mv - Mv(sr << MV_FRACTIONAL_BITS_INTERNAL, sr << MV_FRACTIONAL_BITS_INTERNAL);
  clipMv( mv, blkPos, blkSize, *m_cu.cs->sps, *m_cu.cs->pps );

  // Number of fractional MV bits per component (accounts for chroma subsampling).
  const int lumaShift = 2 + MV_FRACTIONAL_BITS_DIFF;
  const int horShift  = (lumaShift + ::getComponentScaleX(m_compID, m_cu.chromaFormat));
  const int verShift  = (lumaShift + ::getComponentScaleY(m_compID, m_cu.chromaFormat));

  // Split the MV into integer-sample and fractional parts.
  const int xInt  = mv.getHor() >> horShift;
  const int yInt  = mv.getVer() >> verShift;
  const int xFrac = mv.getHor() & ((1 << horShift) - 1);
  const int yFrac = mv.getVer() & ((1 << verShift) - 1);

  const CPelBuf refBuf = refPic.getRecoBuf(refPic.blocks[m_compID]);
  const Pel* ref       = refBuf.bufAt(blkPos.offset(xInt, yInt));
        Pel* dst       = dstBuf.buf;
        int  refStride = refBuf.stride;
        int  dstStride = dstBuf.stride;
        int  bw        = (int)blkSize.width;
        int  bh        = (int)blkSize.height;

  // Filter index 1; presumably the bilinear filter for luma, consistent with the
  // NTAPS_BILINEAR tap count used below -- TODO confirm against InterpolationFilter.
  const int  nFilterIdx   = 1;
  const bool useAltHpelIf = false;
  const bool biMCForDMVR  = false;

  if ( yFrac == 0 )
  {
    // Horizontal-only interpolation (or plain copy when xFrac == 0 as well).
    m_interRes.m_if.filterHor( m_compID, (Pel*) ref, refStride, dst, dstStride, bw, bh, xFrac, true, m_cu.chromaFormat, m_cu.slice->clpRng(m_compID), nFilterIdx, biMCForDMVR, useAltHpelIf );
  }
  else if ( xFrac == 0 )
  {
    // Vertical-only interpolation.
    m_interRes.m_if.filterVer( m_compID, (Pel*) ref, refStride, dst, dstStride, bw, bh, yFrac, true, true, m_cu.chromaFormat, m_cu.slice->clpRng(m_compID), nFilterIdx, biMCForDMVR, useAltHpelIf );
  }
  else
  {
    // Separable 2-D interpolation: horizontal pass into m_ifBuf over bh + taps - 1 rows
    // (starting (taps/2 - 1) rows above), then vertical pass into dst.
    const int vFilterSize = isLuma(m_compID) ? NTAPS_BILINEAR : NTAPS_CHROMA;
    PelBuf tmpBuf = PelBuf(m_interRes.m_ifBuf, Size(bw, bh+vFilterSize-1));

    m_interRes.m_if.filterHor( m_compID, (Pel*)ref - ((vFilterSize>>1) -1)*refStride, refStride, tmpBuf.buf, tmpBuf.stride, bw, bh+vFilterSize-1, xFrac, false, m_cu.chromaFormat, m_cu.slice->clpRng(m_compID), nFilterIdx, biMCForDMVR, useAltHpelIf );
    JVET_J0090_SET_CACHE_ENABLE( false );
    m_interRes.m_if.filterVer( m_compID, tmpBuf.buf + ((vFilterSize>>1) -1)*tmpBuf.stride, tmpBuf.stride, dst, dstStride, bw, bh, yFrac, false, true, m_cu.chromaFormat, m_cu.slice->clpRng(m_compID), nFilterIdx, biMCForDMVR, useAltHpelIf );
    JVET_J0090_SET_CACHE_ENABLE( true );
  }

  return dstBuf;
}

 在deriveMvUni函数中进行单向MV的细化:

template <int tplSize>
void TplMatchingCtrl::deriveMvUni()
{
  if (m_minCost == std::numeric_limits<Distortion>::max())
  {
    m_minCost = xGetTempMatchError<tplSize>(m_mvStart); // 计算初始位置处模板的Cost
  }

  if (m_maxSearchRounds <= 0)
  {
    return;
  }
  // 搜索步长
  int searchStepShift = (m_cu.imv == IMV_4PEL ? MV_FRACTIONAL_BITS_INTERNAL + 2 : MV_FRACTIONAL_BITS_INTERNAL);
  xRefineMvSearch<tplSize, TplMatchingCtrl::TMSEARCH_DIAMOND>(m_maxSearchRounds, searchStepShift);
  xRefineMvSearch<tplSize, TplMatchingCtrl::TMSEARCH_CROSS  >(                1, searchStepShift);
  xRefineMvSearch<tplSize, TplMatchingCtrl::TMSEARCH_CROSS  >(                1, searchStepShift - 1);
#if MULTI_PASS_DMVR
  if (!m_pu.bdmvrRefine)
  {
#endif
  xRefineMvSearch<tplSize, TplMatchingCtrl::TMSEARCH_CROSS  >(                1, searchStepShift - 2);
  xRefineMvSearch<tplSize, TplMatchingCtrl::TMSEARCH_CROSS  >(                1, searchStepShift - 3);
#if MULTI_PASS_DMVR
  }
  else
  {
    xDeriveCostBasedMv<TplMatchingCtrl::TMSEARCH_CROSS>();
  }
#endif
}

标签:const,pu,int,ECM,PelBuf,TM,Template,matching,模板
来源: https://blog.csdn.net/BigDream123/article/details/122030869