从Python转到Numpy(三)
作者:互联网
代码的矢量化意味着要解决的问题本质上是可矢量化的,只需要一些 numpy 技巧即可使代码运行更快,但是矢量化并不是十分容易。
矢量化代码样例:生命游戏(Game of Life)
生命游戏的宇宙是一个二维正交网格,每个格子(细胞)处于两种可能的状态,生或死。每个位于格子里的细胞都与它的八个相邻格子的细胞(水平、垂直或对角相邻的细胞)相互作用。在每个进化步骤中,都会发生以下转换:
- 少于两个邻居的活细胞都会死亡。
- 超过三个邻居的活细胞都会死亡。
- 有两个或三个邻居的活细胞存活,可以保持到下一代。
- 刚好有三个活细胞邻居的死细胞复活
Python 实现
格子使用二维列表表示系统的初始状态,1代表存活,0代表死亡。
Z = [[0,0,0,0,0,0],
[0,0,0,1,0,0],
[0,1,0,1,0,0],
[0,0,1,1,0,0],
[0,0,0,0,0,0],
[0,0,0,0,0,0]]
为简化计算,边界全部设定为0。先定义函数统计一下每个格子相邻的活细胞数量(不计算边界的格子)。
def compute_neighbours(Z):
shape = len(Z), len(Z[0])
N = [[0,]*(shape[0]) for i in range(shape[1])]
# 遍历每一个元素,不包含边界的格子,计算周边元素的之和
for x in range(1,shape[0]-1):
for y in range(1,shape[1]-1):
N[x][y] = Z[x-1][y-1]+Z[x][y-1]+Z[x+1][y-1] \
+ Z[x-1][y] +Z[x+1][y] \
+ Z[x-1][y+1]+Z[x][y+1]+Z[x+1][y+1]
return N
# 测试一下统计结果
compute_neighbours(Z)
out:
[[0, 0, 0, 0, 0, 0],
[0, 1, 3, 1, 2, 0],
[0, 1, 5, 3, 3, 0],
[0, 2, 3, 2, 2, 0],
[0, 1, 2, 2, 1, 0],
[0, 0, 0, 0, 0, 0]]
定义迭代函数,按照上面的四个规则进行下一步的迭代。
def iterate(Z):
shape = len(Z), len(Z[0])
N = compute_neighbours(Z) # 统计相邻元素的活细胞
for x in range(1,shape[0]-1):
for y in range(1,shape[1]-1):
# 如果自身是活细胞 而且邻居的活细胞数量小于2或者大于3,转为死细胞
if Z[x][y] == 1 and (N[x][y] < 2 or N[x][y] > 3):
Z[x][y] = 0
# 否则自身是死细胞而且邻居活细胞数量为3,转为活细胞
elif Z[x][y] == 0 and N[x][y] == 3:
Z[x][y] = 1
return Z
下边是经过5次迭代的可视化结果(不包含边界数据)
完整代码:
def compute_neighbours(Z):
shape = len(Z), len(Z[0])
N = [[0, ]*(shape[0]) for i in range(shape[1])]
for x in range(1, shape[0]-1):
for y in range(1, shape[1]-1):
N[x][y] = Z[x-1][y-1]+Z[x][y-1]+Z[x+1][y-1] \
+ Z[x-1][y] +Z[x+1][y] \
+ Z[x-1][y+1]+Z[x][y+1]+Z[x+1][y+1]
return N
def iterate(Z):
shape = len(Z), len(Z[0])
N = compute_neighbours(Z)
for x in range(1, shape[0]-1):
for y in range(1, shape[1]-1):
if Z[x][y] == 1 and (N[x][y] < 2 or N[x][y] > 3):
Z[x][y] = 0
elif Z[x][y] == 0 and N[x][y] == 3:
Z[x][y] = 1
return Z
if __name__ == '__main__':
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle
Z = [[0, 0, 0, 0, 0, 0],
[0, 0, 0, 1, 0, 0],
[0, 1, 0, 1, 0, 0],
[0, 0, 1, 1, 0, 0],
[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0]]
figure = plt.figure(figsize=(12, 3))
labels = ("Initial state",
"iteration 1", "iteration 2",
"iteration 3", "iteration 4")
for i in range(5): # 经过5次迭代
ax = plt.subplot(1, 5, i+1, aspect=1, frameon=False)
for x in range(1, 5):
for y in range(1, 5):
if Z[x][y] == 1:
facecolor = 'black' # 黑色代表存活
else:
facecolor = 'white' # 白色代表死亡
rect = Rectangle((x, 5-y), width=0.9, height=0.9,
linewidth=1.0, edgecolor='black',
facecolor=facecolor)
ax.add_patch(rect)
ax.set_xlim(.9, 5.1)
ax.set_ylim(.9, 5.1)
ax.set_xticks([])
ax.set_yticks([])
ax.set_xlabel(labels[i])
for tick in ax.xaxis.get_major_ticks():
tick.tick1On = tick.tick2On = False
for tick in ax.yaxis.get_major_ticks():
tick.tick1On = tick.tick2On = False
iterate(Z)
plt.tight_layout()
plt.show()
Numpy版本的实现
观察纯python 版本的实现,统计相邻细胞数和迭代函数都使用了嵌套循环,计算效率是比较低下的,尤其是计算大规模数组时,这一点表现的更为明显。矢量化的主要目的就是尽可能消除循环计算,因此首先考虑如何通过矢量化计算解决相邻细胞数的问题。
问题是怎样一次性获得所有相邻元素?对于一维度数组,这个问题就变成如何一次性获得左右相邻两组元素?如下图:
每个元素(不含边界元素)的全部左邻表示为Z[:-2],同理全部右邻表示为Z[2:],这里对应是不考虑边界元素的,说白了,左边界的左邻或者右边界的右邻不存在。
二位数组的扩展:
Z = np.array(
[[0,0,0,0,0,0],
[0,0,0,1,0,0],
[0,1,0,1,0,0],
[0,0,1,1,0,0],
[0,0,0,0,0,0],
[0,0,0,0,0,0]]
)
N = (
Z[ :-2, :-2] + Z[ :-2, 1:-1] + Z[ :-2, 2:] +
Z[1:-1, :-2] + Z[1:-1, 2:] +
Z[2: , :-2] + Z[2: , 1:-1] + Z[2: , 2:]
)
print(N)
out:
[[1 3 1 2]
[1 5 3 3]
[2 3 2 2]
[1 2 2 1]]
与python版嵌套循环计算结果一致,但少了边界元素,需要补充回来,这点在numpy里就比较简单了,用zeros赋值完成。矢量化计算的好处是省去了嵌套循环,一个加法搞定。
接下来,消除迭代中的循环,把上面4个规则转为numpy实现。
# 第一版numpy实现
# 扁平化数组
N_ = N.ravel()
Z_ = Z.ravel()
# 实现规则
R1 = np.argwhere( (Z_==1) & (N_ < 2) )
R2 = np.argwhere( (Z_==1) & (N_ > 3) )
R3 = np.argwhere( (Z_==1) & ((N_==2) | (N_==3)) )
R4 = np.argwhere( (Z_==0) & (N_==3) )
# 根据规则赋值
Z_[R1] = 0
Z_[R2] = 0
Z_[R3] = Z_[R3]
Z_[R4] = 1
# 设置边界元素为0
Z[0,:] = Z[-1,:] = Z[:,0] = Z[:,-1] = 0
虽然第一版实现省去了嵌套循环,但是进行了4次argwhere的调用会导致性能下降。替代办法是使用numpy的布尔操作。
birth = (N==3)[1:-1,1:-1] & (Z[1:-1,1:-1]==0) # 复生条件
survive = ((N==2) | (N==3))[1:-1,1:-1] & (Z[1:-1,1:-1]==1) # 存活条件
Z[...] = 0 # 数组清零
Z[1:-1,1:-1][birth | survive] = 1 # 把符合生存条件的元素置为1
基于上边的优化逻辑,加大数组规模,并配合动画观察一下升级版的Game of life
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation
def update(*args):
global Z, M
N = (Z[0:-2, 0:-2] + Z[0:-2, 1:-1] + Z[0:-2, 2:] +
Z[1:-1, 0:-2] + Z[1:-1, 2:] +
Z[2: , 0:-2] + Z[2: , 1:-1] + Z[2: , 2:])
birth = (N == 3) & (Z[1:-1, 1:-1] == 0)
survive = ((N == 2) | (N == 3)) & (Z[1:-1, 1:-1] == 1)
Z[...] = 0
Z[1:-1, 1:-1][birth | survive] = 1
# 显示过去迭代过程
M[M>0.25] = 0.25 # 阈值,控制显示元素灰度值
M *= 0.995 # 灰度衰减系数,可以自行修改观察
M[Z==1] = 1
im.set_data(M)
Z = np.random.randint(0, 2, (300, 600))
M = np.zeros(Z.shape)
size = np.array(Z.shape)
dpi = 80.0
figsize = size[1]/float(dpi), size[0]/float(dpi)
fig = plt.figure(figsize=figsize, dpi=dpi)
fig.add_axes([0.0, 0.0, 1.0, 1.0], frameon=False)
im = plt.imshow(M, interpolation='nearest', cmap=plt.cm.gray_r, vmin=0, vmax=1)
plt.xticks([]), plt.yticks([])
#matplotlib动画函数,绘图函数update,调用间隔10ms,2000祯结束。详细参数设置参考文档
animation = FuncAnimation(fig, update, interval=10, frames=2000)
plt.show()
标签:plt,Python,range,细胞,shape,转到,np,ax,Numpy 来源: https://blog.csdn.net/tommystudio/article/details/117983951