其他分享
首页 > 其他分享 > 全连接单次更新

全连接单次更新

作者:互联网


import numpy as np

learning_rate = 0.001

# Network inputs (nodes 1-3) and training targets (nodes 8-9), all 1 here.
x1 = x2 = x3 = 1
y1 = y2 = 1

# Weight wIJ connects node I to node J; every weight starts at 1.
# Input -> hidden, grouped by destination hidden node (4, 5, 6, 7):
w14 = w24 = w34 = 1
w15 = w25 = w35 = 1
w16 = w26 = w36 = 1
w17 = w27 = w37 = 1
# Hidden -> output, grouped by destination output node (8, 9):
w48 = w58 = w68 = w78 = 1
w49 = w59 = w69 = w79 = 1

# Input column vector, shape (3, 1).
X = np.array([[x1], [x2], [x3]])

# Input->hidden weights, shape (4, 3): row r feeds hidden node 4 + r,
# column c carries the signal from input node 1 + c.
W1 = np.array([
    [w14, w24, w34],
    [w15, w25, w35],
    [w16, w26, w36],
    [w17, w27, w37],
])

# Hidden->output weights, shape (2, 4): row r feeds output node 8 + r.
W2 = np.array([
    [w48, w58, w68, w78],
    [w49, w59, w69, w79],
])

# Target column vector, shape (2, 1).
Y = np.array([[y1], [y2]])


# 隐藏层
def hidden(W1, X):
    """Forward pass through the hidden layer.

    Applies the logistic sigmoid elementwise to the weighted input
    ``np.dot(W1, X)``. With the arrays defined in this file (``W1`` of
    shape (4, 3), ``X`` of shape (3, 1)) the result is a (4, 1) column of
    hidden activations.

    Returns:
        ``1 / (1 + exp(-np.dot(W1, X)))``, elementwise.
    """
    return 1 / (1 + np.exp(-np.dot(W1, X)))


# 输出层
def output(W2, A):
    """Forward pass through the output layer.

    Weights the hidden activations ``A`` by ``W2`` and squashes the result
    with the logistic sigmoid. With the arrays defined in this file
    (``W2`` of shape (2, 4), ``A`` of shape (4, 1)) this yields a (2, 1)
    column of predictions.

    Returns:
        ``1 / (1 + exp(-np.dot(W2, A)))``, elementwise.
    """
    weighted_sum = np.dot(W2, A)
    activation = 1 / (1 + np.exp(-weighted_sum))
    return activation


# 梯度
def deta(y_pre, y):
    """Return the error term (delta) for sigmoid output units.

    The delta is the sigmoid derivative evaluated at the prediction,
    ``y_pre * (1 - y_pre)``, multiplied by the error ``y - y_pre``. This is
    the usual output delta for sigmoid units under a squared-error loss. It
    matches the formula written inline in this file's training loop.

    Args:
        y_pre: predicted activations (e.g. the result of ``output``);
            a scalar or an array.
        y: target values, broadcast-compatible with ``y_pre``.

    Returns:
        The elementwise delta, with the broadcast shape of the inputs. Its
        sign follows ``y - y_pre``, which is why ``update`` *adds* the
        scaled step to the weights.
    """
    # The derivative of the sigmoid must use the prediction, not the target:
    # sigmoid'(z) = sigmoid(z) * (1 - sigmoid(z)) = y_pre * (1 - y_pre).
    # Using (1 - y) made the delta vanish whenever the target was 1.
    return y_pre * (1 - y_pre) * (y - y_pre)


# 权重更新
def update(W, learning_rate, deta, X):
    """Return ``W`` after one additive gradient step.

    ``deta`` is a column of per-destination error terms, e.g. shape (k, 1).
    ``X.T`` is a row of source activations, e.g. shape (1, n). Their product
    broadcasts to the (k, n) outer product, giving one correction per
    weight. The step is *added* because the deltas used in this file carry
    the sign of ``target - prediction``.

    ``W`` itself is not modified; a new array is returned.
    """
    scaled_delta = learning_rate * deta
    return W + scaled_delta * X.T


# ---------------------------------------------------------------------------
# Training loop: 100 gradient steps on the hidden->output weights W2 only.
#
# W1 is never reassigned below, so the hidden activations A are the same on
# every pass. Only W2 (and therefore y_pre) changes from one step to the next.
#
# NOTE(review): W1 is not trained. A previous draft of the W1 update built
# the hidden delta as
#     A * (1 - A) * W2.T * np.vstack((deta_3, deta_3))
# A is (4, 1) and W2.T is (4, 2), so that product is a (4, 2) array. It
# cannot broadcast against X.T, shape (1, 3), inside update(). A shape-correct
# hidden delta would be
#     deta_2 = A * (1 - A) * np.dot(W2.T, deta_3)    # shape (4, 1)
# evaluated with W2 *before* its update, followed by
#     W1 = update(W1, learning_rate, deta_2, X)
# ---------------------------------------------------------------------------
for i in range(100):
    # Forward pass through both sigmoid layers.
    A = hidden(W1, X)
    y_pre = output(W2, A)

    # Output-layer delta: slope of the sigmoid at the prediction times the
    # error (Y - y_pre). Then one additive step on W2, with the hidden
    # activations acting as that layer's input.
    sigmoid_slope = y_pre * (1 - y_pre)
    deta_3 = sigmoid_slope * (Y - y_pre)
    W2 = update(W2, learning_rate, deta_3, A)
    print("更新后W2:", W2)

标签:pre,deta,更新,print,W2,W1,np,单次,连接
来源: https://blog.csdn.net/weixin_51062176/article/details/120450908