编程语言
首页 > 编程语言> > 从Python转到Numpy(三)

从Python转到Numpy(三)

作者:互联网

代码的矢量化意味着要解决的问题本质上是可矢量化的,只需要一些 numpy 技巧即可使代码运行更快,但是矢量化并不是十分容易。

矢量化代码样例:生命游戏(Game of Life)

生命游戏的宇宙是一个二维正交网格,每个格子(细胞)处于两种可能的状态,生或死。每个位于格子里的细胞都与它的八个相邻格子的细胞(水平、垂直或对角相邻的细胞)相互作用。在每个进化步骤中,都会发生以下转换:

  1. 少于两个邻居的活细胞都会死亡。
  2. 超过三个邻居的活细胞都会死亡。
  3. 有两个或三个邻居的活细胞存活,可以保持到下一代。
  4. 刚好有三个活细胞邻居的死细胞复活

 

Python 实现

格子使用二维列表表示系统的初始状态,1代表存活,0代表死亡。

Z = [[0,0,0,0,0,0],
     [0,0,0,1,0,0],
     [0,1,0,1,0,0],
     [0,0,1,1,0,0],
     [0,0,0,0,0,0],
     [0,0,0,0,0,0]]

为简化计算,边界全部设定为0。先定义函数统计一下每个格子相邻的活细胞数量(不计算边界的格子)。

def compute_neighbours(Z):
    shape = len(Z), len(Z[0])
    N  = [[0,]*(shape[0]) for i in range(shape[1])]
    # 遍历每一个元素,不包含边界的格子,计算周边元素的之和
    for x in range(1,shape[0]-1): 
        for y in range(1,shape[1]-1):
            N[x][y] = Z[x-1][y-1]+Z[x][y-1]+Z[x+1][y-1] \
                    + Z[x-1][y]            +Z[x+1][y]   \
                    + Z[x-1][y+1]+Z[x][y+1]+Z[x+1][y+1]
    return N
# 测试一下统计结果
compute_neighbours(Z)

out:
[[0, 0, 0, 0, 0, 0],
 [0, 1, 3, 1, 2, 0],
 [0, 1, 5, 3, 3, 0],
 [0, 2, 3, 2, 2, 0],
 [0, 1, 2, 2, 1, 0],
 [0, 0, 0, 0, 0, 0]]

定义迭代函数,按照上面的四个规则进行下一步的迭代。

def iterate(Z):
    shape = len(Z), len(Z[0])
    N = compute_neighbours(Z) # 统计相邻元素的活细胞
    for x in range(1,shape[0]-1):
        for y in range(1,shape[1]-1):
             # 如果自身是活细胞 而且邻居的活细胞数量小于2或者大于3,转为死细胞
             if Z[x][y] == 1 and (N[x][y] < 2 or N[x][y] > 3):
                 Z[x][y] = 0
             # 否则自身是死细胞而且邻居活细胞数量为3,转为活细胞
             elif Z[x][y] == 0 and N[x][y] == 3:
                 Z[x][y] = 1
    return Z

下边是经过5次迭代的可视化结果(不包含边界数据)

完整代码:

def compute_neighbours(Z):
    shape = len(Z), len(Z[0])
    N = [[0, ]*(shape[0]) for i in range(shape[1])]
    for x in range(1, shape[0]-1):
        for y in range(1, shape[1]-1):
            N[x][y] = Z[x-1][y-1]+Z[x][y-1]+Z[x+1][y-1] \
                    + Z[x-1][y]            +Z[x+1][y]   \
                    + Z[x-1][y+1]+Z[x][y+1]+Z[x+1][y+1]
    return N


def iterate(Z):
    shape = len(Z), len(Z[0])
    N = compute_neighbours(Z)
    for x in range(1, shape[0]-1):
        for y in range(1, shape[1]-1):
            if Z[x][y] == 1 and (N[x][y] < 2 or N[x][y] > 3):
                Z[x][y] = 0
            elif Z[x][y] == 0 and N[x][y] == 3:
                Z[x][y] = 1
    return Z

if __name__ == '__main__':
    import matplotlib.pyplot as plt
    from matplotlib.patches import Rectangle

    Z = [[0, 0, 0, 0, 0, 0],
         [0, 0, 0, 1, 0, 0],
         [0, 1, 0, 1, 0, 0],
         [0, 0, 1, 1, 0, 0],
         [0, 0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0, 0]]

    figure = plt.figure(figsize=(12, 3))

    labels = ("Initial state",
              "iteration 1", "iteration 2",
              "iteration 3", "iteration 4")

    for i in range(5): # 经过5次迭代
        ax = plt.subplot(1, 5, i+1, aspect=1, frameon=False)

        for x in range(1, 5):
            for y in range(1, 5):
                if Z[x][y] == 1:
                    facecolor = 'black' # 黑色代表存活
                else:
                    facecolor = 'white' # 白色代表死亡
                rect = Rectangle((x, 5-y), width=0.9, height=0.9,
                                 linewidth=1.0, edgecolor='black',
                                 facecolor=facecolor)
                ax.add_patch(rect)
        ax.set_xlim(.9, 5.1)
        ax.set_ylim(.9, 5.1)
        ax.set_xticks([])
        ax.set_yticks([])
        ax.set_xlabel(labels[i])

        for tick in ax.xaxis.get_major_ticks():
            tick.tick1On = tick.tick2On = False
        for tick in ax.yaxis.get_major_ticks():
            tick.tick1On = tick.tick2On = False

        iterate(Z)

    plt.tight_layout()
    plt.show()

 

Numpy版本的实现

观察纯python 版本的实现,统计相邻细胞数和迭代函数都使用了嵌套循环,计算效率是比较低下的,尤其是计算大规模数组时,这一点表现的更为明显。矢量化的主要目的就是尽可能消除循环计算,因此首先考虑如何通过矢量化计算解决相邻细胞数的问题。

问题是怎样一次性获得所有相邻元素?对于一维度数组,这个问题就变成如何一次性获得左右相邻两组元素?如下图:

每个元素(不含边界元素)的全部左邻表示为Z[:-2],同理全部右邻表示为Z[2:],这里对应是不考虑边界元素的,说白了,左边界的左邻或者右边界的右邻不存在。

二位数组的扩展:

Z = np.array(
    [[0,0,0,0,0,0],
     [0,0,0,1,0,0],
     [0,1,0,1,0,0],
     [0,0,1,1,0,0],
     [0,0,0,0,0,0],
     [0,0,0,0,0,0]]
)

N = (
    Z[ :-2,  :-2] + Z[ :-2, 1:-1] + Z[ :-2, 2:] +
    Z[1:-1,  :-2]                 + Z[1:-1, 2:] +
    Z[2:  ,  :-2] + Z[2:  , 1:-1] + Z[2:  , 2:]
)

print(N)

out:

[[1 3 1 2]
 [1 5 3 3]
 [2 3 2 2]
 [1 2 2 1]]

与python版嵌套循环计算结果一致,但少了边界元素,需要补充回来,这点在numpy里就比较简单了,用zeros赋值完成。矢量化计算的好处是省去了嵌套循环,一个加法搞定。

 

接下来,消除迭代中的循环,把上面4个规则转为numpy实现。

# 第一版numpy实现
# 扁平化数组
N_ = N.ravel()
Z_ = Z.ravel()

# 实现规则
R1 = np.argwhere( (Z_==1) & (N_ < 2) )
R2 = np.argwhere( (Z_==1) & (N_ > 3) )
R3 = np.argwhere( (Z_==1) & ((N_==2) | (N_==3)) )
R4 = np.argwhere( (Z_==0) & (N_==3) )

# 根据规则赋值
Z_[R1] = 0
Z_[R2] = 0
Z_[R3] = Z_[R3]
Z_[R4] = 1

# 设置边界元素为0
Z[0,:] = Z[-1,:] = Z[:,0] = Z[:,-1] = 0

 

虽然第一版实现省去了嵌套循环,但是进行了4次argwhere的调用会导致性能下降。替代办法是使用numpy的布尔操作。

birth = (N==3)[1:-1,1:-1] & (Z[1:-1,1:-1]==0) # 复生条件
survive = ((N==2) | (N==3))[1:-1,1:-1] & (Z[1:-1,1:-1]==1) # 存活条件
Z[...] = 0 # 数组清零
Z[1:-1,1:-1][birth | survive] = 1 # 把符合生存条件的元素置为1

基于上边的优化逻辑,加大数组规模,并配合动画观察一下升级版的Game of life

 

 

 

import numpy as np
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation


def update(*args):
    global Z, M

    N = (Z[0:-2, 0:-2] + Z[0:-2, 1:-1] + Z[0:-2, 2:] +
         Z[1:-1, 0:-2]                 + Z[1:-1, 2:] +
         Z[2:  , 0:-2] + Z[2:  , 1:-1] + Z[2:  , 2:])
    birth = (N == 3) & (Z[1:-1, 1:-1] == 0)
    survive = ((N == 2) | (N == 3)) & (Z[1:-1, 1:-1] == 1)
    Z[...] = 0
    Z[1:-1, 1:-1][birth | survive] = 1

    # 显示过去迭代过程
    M[M>0.25] = 0.25 # 阈值,控制显示元素灰度值
    M *= 0.995 # 灰度衰减系数,可以自行修改观察
    M[Z==1] = 1
    im.set_data(M)


Z = np.random.randint(0, 2, (300, 600)) 
M = np.zeros(Z.shape)

size = np.array(Z.shape)
dpi = 80.0
figsize = size[1]/float(dpi), size[0]/float(dpi)
fig = plt.figure(figsize=figsize, dpi=dpi)
fig.add_axes([0.0, 0.0, 1.0, 1.0], frameon=False)
im = plt.imshow(M, interpolation='nearest', cmap=plt.cm.gray_r, vmin=0, vmax=1)
plt.xticks([]), plt.yticks([])

#matplotlib动画函数,绘图函数update,调用间隔10ms,2000祯结束。详细参数设置参考文档
animation = FuncAnimation(fig, update, interval=10, frames=2000)
plt.show()

 

标签:plt,Python,range,细胞,shape,转到,np,ax,Numpy
来源: https://blog.csdn.net/tommystudio/article/details/117983951