其他分享
首页 > 其他分享> > MindSpore网络自定义反向报错:TypeError: The params of function 'bprop' of

MindSpore网络自定义反向报错:TypeError: The params of function 'bprop' of

作者:互联网

1. 报错描述

1.1 系统环境

Hardware Environment(Ascend/GPU/CPU): GPU
Software Environment:

1.2 基本信息

1.2.1 源码

import mindspore as ms
import mindspore.nn as nn
from mindspore.common.tensor import Tensor
from mindspore.ops import composite as C

grad_all = C.GradOperation(get_all=True)

class MulAdd(nn.Cell):
    def construct(self, x, y):
        return 2 * x + y

    def bprop(self, x, y, out):
        return 2 * x, 2 * y
mul_add = MulAdd()
x = Tensor(1, dtype=ms.int32)
y = Tensor(2, dtype=ms.int32)
output = grad_all(mul_add)(x, y)

1.2.2 报错

TypeError: The params of function 'bprop' of Primitive or Cell requires the forward inputs as well as the 'out' and 'dout'

Traceback (most recent call last):
  File "test_grad.py", line 20, in <module>
    output = grad_all(mul_add)(x, y)
  File "/home/liangzhibo/mindspore/build/package/mindspore/common/api.py", line 522, in staging_specialize
    out = _MindsporeFunctionExecutor(func, hash_obj, input_signature, process_obj)(*args)
  File "/home/liangzhibo/mindspore/build/package/mindspore/common/api.py", line 93, in wrapper
    results = fn(*arg, **kwargs)
  File "/home/liangzhibo/mindspore/build/package/mindspore/common/api.py", line 353, in __call__
    phase = self.compile(args_list, self.fn.__name__)
  File "/home/liangzhibo/mindspore/build/package/mindspore/common/api.py", line 321, in compile
    is_compile = self._graph_executor.compile(self.fn, compile_args, phase, True)
TypeError: The params of function 'bprop' of Primitive or Cell requires the forward inputs as well as the 'out' and 'dout'.
In file test_grad.py(13)
    def bprop(self, x, y, out):
    ^

----------------------------------------------------
- The Traceback of Net Construct Code:
----------------------------------------------------

# In file test_grad.py(13)
    def bprop(self, x, y, out):
    ^

----------------------------------------------------
- C++ Call Stack: (For framework developers)
----------------------------------------------------
mindspore/ccsrc/frontend/optimizer/ad/kprim.cc:651 BuildOutput

2. 原因分析与解决方法

在这个用例中, 我们使用了Cell的自定义反向规则。 而报错信息也提示了我们是自定义规则的输入, 即

def bprop(self, x, y, out):

这句话存在错误。 

在自定义Cell的反向规则bprop时, 需要接受三类输入, 分别是Cell的正向输入(在本用例中为x, y), Cell的正向输出(在本用例中为out),以及输入网络反向的累加梯度(dout)。本用例中正式因为缺少了dout输入, 因此运行失败。 因此我们只需要将代码更改为:

def bprop(self, x, y, out, dout):
    return 2 * x, 2 * y

 程序即可正常运行。

下图表示了三类输入分别的意义, dout为反向图前一个节点输出的梯度, bprop函数需要此输入来对计算的梯度进行继承与使用。

Untitled Diagram.png

另外, bprop的三类输入是构图的时候需要使用的, 因此即使某些输入在bprop函数中没有被使用, 也是需要传入bprop中的。

3. 参考文档

https://www.mindspore.cn/tutorials/zh-CN/master/advanced/network/derivation.html#%E8%87%AA%E5%AE%9A%E4%B9%89%E5%8F%8D%E5%90%91%E4%BC%A0%E6%92%AD%E5%87%BD%E6%95%B0

标签:function,自定义,bprop,self,py,Cell,报错,mindspore,out
来源: https://www.cnblogs.com/skytier/p/16485291.html