# 全连接单次更新 (fully connected network: single weight update)
# 作者:互联网 (author: from the internet)
import numpy as np
# Step size applied to every weight update.
learning_rate = 0.001

# One toy training sample: three input values and two target values, all 1.
x1, x2, x3 = 1, 1, 1
y1, y2 = 1, 1

# Input -> hidden weights. The wNM names appear to mean "from node N to
# node M" (inputs 1-3, hidden units 4-7). Every weight starts at 1.
# NOTE(review): identical initial weights make every hidden unit compute the
# same value, so they cannot specialise; fine for a hand-checkable demo.
w14 = w24 = w34 = 1
w15 = w25 = w35 = 1
w16 = w26 = w36 = 1
w17 = w27 = w37 = 1
# Hidden -> output weights (hidden units 4-7 into outputs 8-9), also all 1.
w48 = w58 = w68 = w78 = 1
w49 = w59 = w69 = w79 = 1

# Input column vector, shape (3, 1).
X = np.array([[x1], [x2], [x3]])
# One row per hidden unit, one column per input -> shape (4, 3), so that
# np.dot(W1, X) gives one pre-activation per hidden unit.
W1 = np.array([
    [w14, w24, w34],
    [w15, w25, w35],
    [w16, w26, w36],
    [w17, w27, w37],
])
# One row per output unit, one column per hidden unit -> shape (2, 4).
W2 = np.array([
    [w48, w58, w68, w78],
    [w49, w59, w69, w79],
])
# Target column vector, shape (2, 1).
Y = np.array([[y1], [y2]])
# Hidden layer
def hidden(W1, X):
    """Run the hidden layer forward: elementwise logistic sigmoid of W1 . X.

    Args:
        W1: weight matrix, one row per hidden unit ((4, 3) in this script).
        X: input column vector ((3, 1) in this script).

    Returns:
        The hidden activations, sigmoid(np.dot(W1, X)); (4, 1) in this script.
    """
    return 1.0 / (1.0 + np.exp(-np.dot(W1, X)))
# Output layer
def output(W2, A):
    """Run the output layer forward: elementwise logistic sigmoid of W2 . A.

    Args:
        W2: weight matrix, one row per output unit ((2, 4) in this script).
        A: hidden-layer activations as a column vector ((4, 1) in this script).

    Returns:
        The network predictions, sigmoid(np.dot(W2, A)); (2, 1) in this script.
    """
    net = np.dot(W2, A)        # weighted input of each output unit
    decay = np.exp(-net)
    return 1 / (1 + decay)
# Gradient (error term) of the output layer
def deta(y_pre, y):
    """Return the output-layer delta for sigmoid units trained on squared error.

        delta = y_pre * (1 - y_pre) * (y - y_pre)

    ``y_pre * (1 - y_pre)`` is the sigmoid derivative written in terms of the
    sigmoid's own output, so it must use the prediction ``y_pre``, not the
    target ``y``. Using ``(1 - y)`` would zero the gradient whenever the target
    is 1. This matches the expression the training loop in this file computes
    inline for ``deta_3``.

    Args:
        y_pre: network prediction (sigmoid output); scalar or array.
        y: target value(s), broadcastable against ``y_pre``.

    Returns:
        The delta, broadcast to the common shape of ``y_pre`` and ``y``.
    """
    return y_pre * (1 - y_pre) * (y - y_pre)
# Weight update
def update(W, learning_rate, deta, X):
    """Return ``W`` moved one step along ``deta`` times the layer input ``X``.

    For column vectors ``deta`` (m, 1) and ``X`` (n, 1), ``deta * X.T``
    broadcasts to the (m, n) outer product, giving one increment per weight
    of ``W``. The step is added rather than subtracted because, as built in
    this file, ``deta`` already carries the ``(target - prediction)`` sign.
    """
    scaled = learning_rate * deta
    return W + scaled * X.T
# A = hidden(W1, X)
# y_pre = output(W2, A)
# --------------- 更新W2-----------------
# deta_3 = y_pre * (1 - y_pre) * (Y - y_pre)
# W2 = update(W2, learning_rate, deta_3, A)
# print("更新后W2:", W2)
# --------------- 更新W1-----------------
# print('A:', A)
# print('W2:', W2.T)
# print('deta_3:', deta_3)
#
#
# n_deta_3 = np.vstack((deta_3,deta_3))
# print(n_deta_3,'-------------')
#
# deta_2 = A * (1 - A) * (W2.T) * n_deta_3
# W1 = update(W1, learning_rate, deta_2, X)
# print("更新后W1:", W1)
# Train for 100 steps. Only the hidden->output weights W2 change here.
# NOTE(review): W1 is never updated (its update step above is commented out),
# so the hidden layer keeps its initial weights throughout.
for i in range(100):
    # Forward pass: inputs -> hidden activations -> predictions.
    A = hidden(W1, X)
    y_pre = output(W2, A)
    # --------------- Update W2 -----------------
    # Output delta = sigmoid'(output) * (target - prediction).
    deta_3 = Y - y_pre
    deta_3 = y_pre * (1 - y_pre) * deta_3
    W2 = update(W2, learning_rate, deta_3, A)
# Report the output-layer weights after training.
print("更新后W2:", W2)
# 标签:pre,deta,更新,print,W2,W1,np,单次,连接 来源: https://blog.csdn.net/weixin_51062176/article/details/120450908