编程语言
首页 > 编程语言> > python 笔记:dtw包

python 笔记:dtw包

作者:互联网

1 作用

用来辅助计算DTW的python模块

2 基本使用方法

2.1 数据

假设有两个序列

import numpy as np


x = np.array([1,3,2,4,2])
y = np.array([0,3,4,2,2])
plt.plot(x,'green')
plt.plot(y,'blue')
plt.legend(['x','y'])
plt.show()

 

我们要计算这两个序列之间的dtw

 2.2 定义距离函数

我们首先要定义两个序列间任意两个点xi,yj之间的距离 

manhattan_distance = lambda x, y: np.abs(x - y)

2.3 使用dtw

from dtw import dtw
d, cost_matrix, acc_cost_matrix, path = dtw(x, y, dist=manhattan_distance)

2.4 返回参数意义

d就是两个序列间dtw的值,算出来是2,和DTW _UQI-LIUWJ的博客 一致

cost_matrix就是 用前面的manhattan_distance 算出来的两个序列之间的两两距离

cost_matrix
'''
array([[1., 2., 3., 1., 1.],
       [3., 0., 1., 1., 1.],
       [2., 1., 2., 0., 0.],
       [4., 1., 0., 2., 2.],
       [2., 1., 2., 0., 0.]])
'''

和之前手动算的一致

 acc_cost_matrix 也就是 DTW _UQI-LIUWJ的博客 的dp矩阵

 

acc_cost_matrix
'''
array([[ 1.,  3.,  6.,  7.,  8.],
       [ 4.,  1.,  2.,  3.,  4.],
       [ 6.,  2.,  3.,  2.,  2.],
       [10.,  3.,  2.,  4.,  4.],
       [12.,  4.,  4.,  2.,  2.]])
'''

 和之前手动算的一致

path就是对应关系

path
#(array([0, 1, 2, 3, 4, 4]), array([0, 1, 1, 2, 3, 4]))
plt.imshow(cost_matrix.T,origin='lower',cmap='gray')
plt.plot(path[0],path[1])
plt.show()

 每次沿着颜色最深的点走 

 

标签:plt,matrix,python,笔记,cost,path,array,dtw
来源: https://blog.csdn.net/qq_40206371/article/details/122644736