
MXNet MakeLoss

Author: Internet

 

https://github.com/jacke121/Fairface-Recognition-Solution

https://github.com/paranoidai/Fairface-Recognition-Solution/blob/7f12bc4462cc765fe8d7a7fa820c63bfe2cc9121/train/pair_wise_loss.py

The code linked above contains several loss functions, selected by loss_type, for example:

if loss_type =='triplet':
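The body of the triplet branch is not reproduced here. As a rough illustration only, the commonly used triplet loss formulation can be sketched in NumPy as below. This is not the repository's code: the function name `triplet_loss`, the `margin=0.5` default, and the sample embeddings are all assumptions made for the example.

```python
import numpy as np


def triplet_loss(anchor, positive, negative, margin=0.5):
    """Mean hinge loss on squared Euclidean distances (margin is a hypothetical default)."""
    # Squared distance from each anchor to its positive and negative sample.
    pos_dist = np.sum((anchor - positive) ** 2, axis=1)
    neg_dist = np.sum((anchor - negative) ** 2, axis=1)
    # A triplet contributes only when the negative is not at least
    # `margin` farther away than the positive.
    return np.mean(np.maximum(pos_dist - neg_dist + margin, 0.0))


# Two hypothetical triplets of 2-D embeddings.
anchor = np.array([[0.0, 0.0], [1.0, 1.0]])
positive = np.array([[0.0, 1.0], [1.0, 1.0]])
negative = np.array([[2.0, 0.0], [1.0, 1.5]])

loss_value = triplet_loss(anchor, positive, negative)
print(loss_value)
```

The first triplet already satisfies the margin and contributes zero; only the second one, whose negative sits close to the anchor, adds to the loss.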

 

 

# -*- coding=utf-8 -*-

import mxnet as mx
import numpy as np
import logging

logging.basicConfig(level=logging.INFO)

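# Linear model: a single fully connected unit on top of the input
x = mx.sym.Variable('data')
y = mx.sym.FullyConnected(data=x, num_hidden=1)
label = mx.sym.Variable('label')
# MakeLoss marks the squared error as the objective to backpropagate from
loss = mx.sym.MakeLoss(mx.sym.square(y - label))
# Group the prediction (gradient blocked) with the loss so both are output
pred_loss = mx.sym.Group([mx.sym.BlockGrad(y), loss])
ex = pred_loss.simple_bind(mx.cpu(), data=(32, 2))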

# Random inputs and labels for one forward/backward pass
test_data = mx.nd.array(np.random.random(size=(32, 2)))
test_label = mx.nd.array(np.random.random(size=(32, 1)))

ex.forward(is_train=True, data=test_data, label=test_label)
ex.backward()

print(ex.arg_dict)
fc_w = ex.arg_dict['fullyconnected0_weight'].asnumpy()
fc_w_grad = ex.grad_dict['fullyconnected0_weight'].asnumpy()
fc_bias = ex.arg_dict['fullyconnected0_bias'].asnumpy()
fc_bias_grad = ex.grad_dict['fullyconnected0_bias'].asnumpy()

logging.info('fc_weight:{}, fc_weights_grad:{}'.format(fc_w, fc_w_grad))
logging.info('fc_bias:{}, fc_bias_grad:{}'.format(fc_bias, fc_bias_grad))
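
To sanity-check the gradients logged above without MXNet, the same computation can be sketched in plain NumPy. The sketch assumes MakeLoss is used with its default settings, so that every element of the loss receives a head gradient of 1 and the weight gradient is a sum over the batch rather than a mean. The helper names (`total_loss`, `analytic_grads`, `numeric_grad`) and the random initial weights are made up for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)

# Same shapes as the example: batch of 32, 2 input features, 1 hidden unit.
x = rng.random((32, 2))
label = rng.random((32, 1))
w = rng.normal(size=(1, 2))  # FullyConnected weight: (num_hidden, input_dim)
b = np.zeros(1)              # FullyConnected bias: (num_hidden,)


def total_loss(w, b):
    """Sum of per-sample squared errors (head gradient of 1 per element)."""
    y = x @ w.T + b
    return np.sum((y - label) ** 2)


def analytic_grads(w, b):
    """Backpropagate a gradient of 1 through square(y - label)."""
    y = x @ w.T + b
    dy = 2.0 * (y - label)
    return dy.T @ x, dy.sum(axis=0)


def numeric_grad(f, param, eps=1e-6):
    """Central finite differences, perturbing one entry of param at a time."""
    grad = np.zeros_like(param)
    for idx in np.ndindex(param.shape):
        orig = param[idx]
        param[idx] = orig + eps
        up = f()
        param[idx] = orig - eps
        down = f()
        param[idx] = orig  # restore the entry before moving on
        grad[idx] = (up - down) / (2 * eps)
    return grad


w_grad, b_grad = analytic_grads(w, b)
w_num = numeric_grad(lambda: total_loss(w, b), w)
b_num = numeric_grad(lambda: total_loss(w, b), b)
print(np.allclose(w_grad, w_num, atol=1e-4), np.allclose(b_grad, b_num, atol=1e-4))
```

Feeding the same data, labels, and initial parameters into the executor should give matching values for fc_w_grad and fc_bias_grad, provided the assumption about the head gradient holds.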

 

Source: https://blog.csdn.net/jacke121/article/details/116867619