激活函数之tanh介绍及C++/PyTorch实现
作者:互联网
深度神经网络中使用的激活函数有很多种,这里介绍下tanh。它的公式如下,截图来自于维基百科(https://en.wikipedia.org/wiki/Activation_function):
tanh又称双曲正切,它解决了sigmoid非零中心问题。tanh取值范围在(-1, 1)内,它也是非线性的。它也不能完全解决梯度消失问题。
C++实现如下:
template<typename _Tp>
int activation_function_tanh(const _Tp* src, _Tp* dst, int length)
{
for (int i = 0; i < length; ++i) {
_Tp ep = std::exp(src[i]);
_Tp em = std::exp(-src[i]);
dst[i] = (ep - em) / (ep + em);
}
return 0;
}
template<typename _Tp>
int activation_function_tanh_derivative(const _Tp* src, _Tp* dst, int length)
{
for (int i = 0; i < length; ++i) {
dst[i] = (_Tp)1. - src[i] * src[i];
}
return 0;
}
int test_activation_function()
{
std::vector<float> src{ 1.1f, -2.2f, 3.3f, 0.4f, -0.5f, -1.6f };
int length = src.size();
std::vector<float> dst(length);
fprintf(stderr, "source vector: \n");
fbc::print_matrix(src);
fprintf(stderr, "calculate activation function:\n");
fprintf(stderr, "type: tanh result: \n");
fbc::activation_function_tanh(src.data(), dst.data(), length);
fbc::print_matrix(dst);
fprintf(stderr, "type: tanh derivative result: \n");
fbc::activation_function_tanh_derivative(dst.data(), dst.data(), length);
fbc::print_matrix(dst);
}
执行结果如下:
Python和PyTorch实现如下:
import numpy as np
import torch
data = [1.1, -2.2, 3.3, 0.4, -0.5, -1.6]
# numpy impl
def tanh(x):
lists = list()
for i in range(len(x)):
lists.append((np.exp(x[i]) - np.exp(-x[i])) / (np.exp(x[i]) + np.exp(-x[i])))
return lists
def tanh_derivative(x):
return 1 - np.power(tanh(x), 2)
output = [round(value, 4) for value in tanh(data)] # 通过round保留小数点后4位
print("numpy tanh:", output)
print("numpt tanh derivative:", [round(value, 4) for value in tanh_derivative(data)])
print("numpt tanh derivative2:", [round(1. - value*value, 4) for value in tanh(data)])
# call pytorch interface
input = torch.FloatTensor(data)
m = torch.nn.Tanh()
output2 = m(input)
print("pytorch tanh:", output2)
print("pytorch tanh derivative:", 1. - output2*output2)
执行结果如下:
由以上执行结果可知:C++、Python、PyTorch三种实现方式结果完全一致。
GitHub:
https://github.com/fengbingchun/NN_Test
https://github.com/fengbingchun/PyTorch_Test
标签:src,tanh,int,dst,C++,PyTorch,print,data 来源: https://blog.csdn.net/fengbingchun/article/details/119202855