Alternating Direction Method of Multipliers (ADMM) Algorithm

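The Alternating Direction Method of Multipliers (ADMM) is an algorithm that solves a convex optimization problem by splitting it into a sequence of subproblems, each of which is easy to solve. It is now widely used in many fields [2].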
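For orientation, the general problem and the scaled-form ADMM iteration from [2] can be summarized as follows. The symbols $f$, $g$, $M$, $N$ and $c$ are generic placeholders (the constraint matrices are renamed here so they do not clash with the LASSO matrix $A$), and $u$ is the scaled dual variable:

$$\min_{x,z}\; f(x) + g(z) \quad \text{s.t.} \quad Mx + Nz = c$$

$$x^{k+1} = \arg\min_x \left( f(x) + \frac{\rho}{2}\left\|Mx + Nz^k - c + u^k\right\|_2^2 \right)$$

$$z^{k+1} = \arg\min_z \left( g(z) + \frac{\rho}{2}\left\|Mx^{k+1} + Nz - c + u^k\right\|_2^2 \right)$$

$$u^{k+1} = u^k + Mx^{k+1} + Nz^{k+1} - c$$

For the LASSO one takes $f(x) = \frac{1}{2}\|y - Ax\|_2^2$, $g(z) = \lambda\|z\|_1$, $M = I$, $N = -I$ and $c = 0$ (that is, the constraint $x = z$), which turns the three steps into the $x$-, $z$- and $u$-updates used in the implementation.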

This is a simplified version, specialized for the LASSO:

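Given a sparse vector $x \in \mathbb{R}^n$ and a matrix $A \in \mathbb{R}^{m \times n}$, the observations are

$$y = Ax + e,$$

where $e$ is additive white Gaussian noise. To recover the signal $x$, we solve the following minimization problem:

$$\hat{x} = \arg\min_x \frac{1}{2}\|y - Ax\|_2^2 + \lambda\|x\|_1$$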
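To make the measurement model and the objective concrete, here is a minimal NumPy sketch. The helper names `lasso_objective` and `make_problem`, the problem sizes, the Gaussian amplitudes and the use of 1-D arrays (the implementation below uses $(n, 1)$ column vectors) are illustrative assumptions; the objective includes the factor $\frac{1}{2}$ that pairs with the update formulas.

```python
import numpy as np


def lasso_objective(A, y, x, lam):
    """Hypothetical helper: value of 1/2 * ||y - A x||_2^2 + lam * ||x||_1."""
    residual = y - A @ x
    return 0.5 * float(residual @ residual) + lam * float(np.sum(np.abs(x)))


def make_problem(m, n, k, noise_std=1.0, seed=0):
    """Hypothetical helper: random A, a k-sparse x and y = A x + e."""
    rng = np.random.default_rng(seed)
    A = rng.standard_normal((m, n))
    x = np.zeros(n)
    support = rng.choice(n, size=k, replace=False)  # k distinct nonzero positions
    x[support] = rng.standard_normal(k)
    e = noise_std * rng.standard_normal(m)  # additive white Gaussian noise
    return A, x, A @ x + e


A, x_true, y = make_problem(m=20, n=50, k=3)
# At x = 0 the objective reduces to 1/2 * ||y||_2^2
print(lasso_objective(A, y, np.zeros(50), lam=0.1))
```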


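ADMM introduces a copy $z$ of $x$ with the constraint $x = z$ and a scaled dual variable $u$. The following three updates are then computed iteratively until a convergence criterion is met:

$$x^{k+1} = (A^TA + \rho I)^{-1}\left(A^Ty + \rho(z^k - u^k)\right)$$

$$z^{k+1} = \mathrm{sign}(x^{k+1} + u^k)\max\left(0, \left|x^{k+1} + u^k\right| - \frac{\lambda}{\rho}\right)$$

$$u^{k+1} = u^k + x^{k+1} - z^{k+1}$$

In the $z$-update, $\mathrm{sign}$, $\max$ and $|\cdot|$ act elementwise; this is the soft-thresholding operator with threshold $\lambda/\rho$.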
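The $x$-, $z$- and $u$-updates can be sketched as one NumPy function. This is an illustration under stated assumptions, not the post's code: `soft_threshold` and `admm_step` are hypothetical names, 1-D arrays are assumed, and the $x$-update calls `np.linalg.solve` at every step for brevity instead of precomputing the inverse of $A^TA + \rho I$ once.

```python
import numpy as np


def soft_threshold(v, kappa):
    """Elementwise soft thresholding: sign(v) * max(0, |v| - kappa)."""
    return np.sign(v) * np.maximum(0.0, np.abs(v) - kappa)


def admm_step(A, y, z, u, lam, rho):
    """One scaled-form ADMM iteration for 1/2 * ||y - A x||_2^2 + lam * ||x||_1."""
    n = A.shape[1]
    # x-update: solve (A^T A + rho I) x = A^T y + rho (z - u)
    x_new = np.linalg.solve(A.T @ A + rho * np.eye(n), A.T @ y + rho * (z - u))
    # z-update: soft-threshold x^{k+1} + u^k at lam / rho
    z_new = soft_threshold(x_new + u, lam / rho)
    # u-update: add the residual x^{k+1} - z^{k+1} to the scaled dual variable
    u_new = u + x_new - z_new
    return x_new, z_new, u_new


# Usage on a small random problem (sizes, lam and rho are arbitrary choices)
rng = np.random.default_rng(0)
A = rng.standard_normal((30, 10))
y = rng.standard_normal(30)
x, z, u = np.zeros(10), np.zeros(10), np.zeros(10)
for _ in range(1000):
    x, z, u = admm_step(A, y, z, u, lam=0.5, rho=10.0)
print(np.linalg.norm(x - z))  # primal residual ||x - z||, near zero once converged
```

The printed primal residual is one half of a typical stopping test; the other half is the dual residual $\rho\|z^{k+1} - z^k\|_2$.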

Below is a Python implementation of the ADMM algorithm (see also http://stanford.edu/~boyd/admm.html).

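```python
import numpy as np
import matplotlib.pyplot as plt
from math import sqrt, log


def Sthresh(x, gamma):
    # Elementwise soft-thresholding: sign(x) * max(0, |x| - gamma)
    return np.sign(x) * np.maximum(0, np.absolute(x) - gamma)


def ADMM(A, y, MAX_ITER=10000, tol=1e-6):
    # Solve min 1/2*||y - Ax||_2^2 + lamb*||x||_1
    # via the alternating direction method of multipliers
    _, n = A.shape

    xhat = np.zeros([n, 1])
    zhat = np.zeros([n, 1])
    u = np.zeros([n, 1])

    # Pre-compute to save some multiplications
    AtA = A.T.dot(A)
    Aty = A.T.dot(y)

    # Heuristic regularization weight and penalty parameter:
    # lamb = sqrt(2*log10(n)), rho = 1 / (largest eigenvalue of A^T A)
    w = np.linalg.eigvalsh(AtA)  # AtA is symmetric, so its eigenvalues are real
    lamb = sqrt(2 * log(n, 10))
    rho = 1 / np.amax(np.absolute(w))

    # (A^T A + rho*I)^{-1} does not change between iterations, so invert it once
    Q = np.linalg.inv(AtA + rho * np.identity(n))

    for _ in range(MAX_ITER):
        z_old = zhat

        # x minimisation step: regularized least-squares solve
        xhat = Q.dot(Aty + rho * (zhat - u))

        # z minimisation via soft-thresholding with threshold lamb/rho,
        # matching the z-update formula
        zhat = Sthresh(xhat + u, lamb / rho)

        # Scaled multiplier (dual) update
        u = u + xhat - zhat

        # Stop once the primal residual ||x - z|| and the
        # dual residual rho*||z - z_old|| are both below tol
        r_norm = np.linalg.norm(xhat - zhat)
        s_norm = rho * np.linalg.norm(zhat - z_old)
        if r_norm < tol and s_norm < tol:
            break

    return zhat, rho, lamb


def test(m=50, n=200):
    """Test the ADMM method with randomly generated matrices and vectors"""
    A = np.random.randn(m, n)

    # Sparse ground truth with num_non_zeros distinct nonzero positions
    num_non_zeros = 10
    positions = np.random.choice(n, num_non_zeros, replace=False)
    amplitudes = 100 * np.random.randn(num_non_zeros, 1)
    x = np.zeros((n, 1))
    x[positions] = amplitudes

    # Noisy measurements y = Ax + e with unit-variance Gaussian noise
    y = A.dot(x) + np.random.randn(m, 1)

    xhat, rho, lamb = ADMM(A, y)

    plt.plot(x, label='Original')
    plt.plot(xhat, label='Estimate')
    plt.legend(loc='upper right')
    plt.show()


if __name__ == "__main__":
    test()
```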
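A solver's output can be sanity-checked against the LASSO optimality (subgradient) conditions. The sketch below assumes the objective $\frac{1}{2}\|y - Ax\|_2^2 + \lambda\|x\|_1$ and 1-D arrays; `lasso_kkt_violation` is a hypothetical helper, exercised here on $A = I$, where the exact minimizer is $y$ soft-thresholded at $\lambda$.

```python
import numpy as np


def lasso_kkt_violation(A, y, x, lam, tol=1e-8):
    """Hypothetical helper: largest violation of the optimality conditions of
    1/2 * ||y - A x||_2^2 + lam * ||x||_1.

    At a minimizer, g = A^T (y - A x) satisfies
      g_i == lam * sign(x_i)   where x_i != 0
      |g_i| <= lam             where x_i == 0
    Entries with |x_i| <= tol are treated as zero.
    """
    g = A.T @ (y - A @ x)
    active = np.abs(x) > tol
    on_support = np.abs(g[active] - lam * np.sign(x[active]))
    off_support = np.maximum(0.0, np.abs(g[~active]) - lam)
    return float(np.max(np.concatenate([on_support, off_support])))


# With A = I the LASSO minimizer is known in closed form: soft-threshold y at lam
lam = 1.0
y = np.array([3.0, -0.2, -2.0])
x_exact = np.sign(y) * np.maximum(0.0, np.abs(y) - lam)
print(lasso_kkt_violation(np.eye(3), y, x_exact, lam))  # 0.0
print(lasso_kkt_violation(np.eye(3), y, y, lam))        # 1.0, so y itself is not optimal
```

To check an ADMM run, pass the same `A`, the flattened `y` and estimate (e.g. via `.ravel()`), and the $\lambda$ of the objective the solver actually minimizes; a value close to zero indicates an approximate minimizer.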

References:
[1] https://codereview.stackexchange.com/questions/108263/alternating-direction-method-of-multipliers
[2] http://stanford.edu/~boyd/admm.html

Source: https://blog.csdn.net/x5675602/article/details/88993976