tf.gather,tf.gather_nd,tf.boolean_mask
Author: (from the Internet)
Links to the function definitions:
tf.gather:https://tensorflow.google.cn/versions/r1.15/api_docs/python/tf/gather
tf.gather_nd:https://tensorflow.google.cn/versions/r1.15/api_docs/python/tf/gather_nd
tf.boolean_mask:https://tensorflow.google.cn/versions/r1.15/api_docs/python/tf/boolean_mask
Differences:
1.tf.gather
tf.gather(
params, indices, validate_indices=None, name=None, axis=None, batch_dims=0
)
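Input: params, with shape [p1, p2, p3, p4, ...]
indices, with shape [i1, i2, ...] (integer positions along axis)
axis: the dimension of params to gather along (defaults to 0)
Output: the slices of params along dimension axis at the positions given by the values in indices. That dimension is replaced by the whole shape of indices, so the output shape is params.shape[:axis] + indices.shape + params.shape[axis+1:].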
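The shape rule above can be checked without TensorFlow. The helper below, `gather_output_shape` (a hypothetical name used only for illustration), is a minimal stdlib sketch of the batch_dims=0 case; it only computes shapes and does not call any TensorFlow API.

```python
def gather_output_shape(params_shape, indices_shape, axis=0):
    """Shape of tf.gather(params, indices, axis=axis) when batch_dims=0.

    The gathered dimension of params is replaced by the full shape of indices:
    params.shape[:axis] + indices.shape + params.shape[axis + 1:]
    """
    params_shape, indices_shape = tuple(params_shape), tuple(indices_shape)
    rank = len(params_shape)
    if axis < 0:
        axis += rank  # negative axis counts from the last dimension
    if not 0 <= axis < rank:
        raise ValueError(f"axis {axis} is out of range for rank {rank}")
    return params_shape[:axis] + indices_shape + params_shape[axis + 1:]


# The three cases below, for data of shape (2, 3, 4):
print(gather_output_shape((2, 3, 4), (2,), axis=0))    # (2, 3, 4)
print(gather_output_shape((2, 3, 4), (2,), axis=1))    # (2, 2, 4)
print(gather_output_shape((2, 3, 4), (1, 2), axis=1))  # (2, 1, 2, 4)
```

Only the shapes of params and indices matter here, which is why the three cases can be predicted before running them.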
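import tensorflow as tf  # eager execution assumed (default in TF 2.x), matching the printed outputs
data=tf.reshape(tf.range(24),(2,3,4))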
'''<tf.Tensor: id=1372, shape=(2, 3, 4), dtype=int32, numpy=
array([[[ 0, 1, 2, 3],
[ 4, 5, 6, 7],
[ 8, 9, 10, 11]],
[[12, 13, 14, 15],
[16, 17, 18, 19],
[20, 21, 22, 23]]])>'''
#case 1:
tf.gather(data,[0,1],axis=0)
'''
<tf.Tensor: id=1375, shape=(2, 3, 4), dtype=int32, numpy=
array([[[ 0, 1, 2, 3],
[ 4, 5, 6, 7],
[ 8, 9, 10, 11]],
[[12, 13, 14, 15],
[16, 17, 18, 19],
[20, 21, 22, 23]]])>
'''
#case 2:
tf.gather(data,[0,1],axis=1)
'''
<tf.Tensor: id=1378, shape=(2, 2, 4), dtype=int32, numpy=
array([[[ 0, 1, 2, 3],
[ 4, 5, 6, 7]],
[[12, 13, 14, 15],
[16, 17, 18, 19]]])>
'''
#case 3:
tf.gather(data,[[0,1]],axis=1)
'''
<tf.Tensor: id=1381, shape=(2, 1, 2, 4), dtype=int32, numpy=
array([[[[ 0, 1, 2, 3],
[ 4, 5, 6, 7]]],
[[[12, 13, 14, 15],
[16, 17, 18, 19]]]])>
'''
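To see the values (not only the shapes), here is a pure-Python sketch of the same semantics on nested lists. `gather` and `nested_shape` are hypothetical helpers written for this illustration, not TensorFlow code; the sketch assumes batch_dims=0, a non-negative axis, and in-range, non-negative indices (Python would silently wrap negative ones).

```python
def nested_shape(x):
    """Shape of a rectangular nested list, e.g. [[1, 2, 3]] -> (1, 3)."""
    shape = []
    while isinstance(x, list):
        shape.append(len(x))
        if not x:
            break
        x = x[0]
    return tuple(shape)


def gather(params, indices, axis=0):
    """Mimic tf.gather(params, indices, axis=axis) with batch_dims=0."""
    if axis > 0:
        # Dimensions before `axis` are kept: apply the gather inside each slice.
        return [gather(sub, indices, axis - 1) for sub in params]
    if isinstance(indices, int):
        return params[indices]  # a scalar index selects one slice
    # Keep the nesting of `indices`, replacing every integer i by params[i].
    return [gather(params, sub, 0) for sub in indices]


# Same values as tf.reshape(tf.range(24), (2, 3, 4)).
data = [[[12 * i + 4 * j + k for k in range(4)] for j in range(3)] for i in range(2)]

print(nested_shape(gather(data, [0, 1], axis=0)))    # (2, 3, 4)
print(gather(data, [0, 1], axis=1))
# [[[0, 1, 2, 3], [4, 5, 6, 7]], [[12, 13, 14, 15], [16, 17, 18, 19]]]
print(nested_shape(gather(data, [[0, 1]], axis=1)))  # (2, 1, 2, 4)
```

The recursion makes the shape rule visible: levels before axis are walked through unchanged, and at axis each integer in indices is swapped for one slice of params while the nesting of indices is preserved.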
2.tf.gather_nd
tf.gather_nd(
params, indices, name=None, batch_dims=0
)
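Input: params, with shape [p1, p2, p3, p4, ...]
indices, with shape [i1, i2, ..., iN]; each vector along the last dimension is one index tuple of length iN (iN <= rank of params) into the first iN dimensions of params
Output: the element or slice of params addressed by each index tuple, assembled into a Tensor of shape indices.shape[:-1] + params.shape[iN:].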
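The output-shape rule for tf.gather_nd can also be written as a small stdlib function. `gather_nd_output_shape` is a hypothetical helper for illustration only, covering the batch_dims=0 case.

```python
def gather_nd_output_shape(params_shape, indices_shape):
    """Shape of tf.gather_nd(params, indices) when batch_dims=0.

    The last dimension of indices (size q) is one index tuple that consumes the
    first q dimensions of params; every other dimension of indices is kept:
    indices.shape[:-1] + params.shape[q:]
    """
    params_shape, indices_shape = tuple(params_shape), tuple(indices_shape)
    if not indices_shape:
        raise ValueError("indices must have rank >= 1")
    q = indices_shape[-1]
    if q > len(params_shape):
        raise ValueError(f"index depth {q} exceeds params rank {len(params_shape)}")
    return indices_shape[:-1] + params_shape[q:]


# The three cases below, for data of shape (2, 3, 4):
print(gather_nd_output_shape((2, 3, 4), (1,)))    # (3, 4)   indices [0]
print(gather_nd_output_shape((2, 3, 4), (2,)))    # (4,)     indices [0, 1]
print(gather_nd_output_shape((2, 3, 4), (1, 2)))  # (1, 4)   indices [[0, 1]]
```

The contrast with tf.gather is that here the last dimension of indices is not copied into the output; it is spent on addressing params.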
data=tf.reshape(tf.range(24),(2,3,4))
'''<tf.Tensor: id=1372, shape=(2, 3, 4), dtype=int32, numpy=
array([[[ 0, 1, 2, 3],
[ 4, 5, 6, 7],
[ 8, 9, 10, 11]],
[[12, 13, 14, 15],
[16, 17, 18, 19],
[20, 21, 22, 23]]])>'''
#case 1:
tf.gather_nd(data,[0])
'''
<tf.Tensor: id=1383, shape=(3, 4), dtype=int32, numpy=
array([[ 0, 1, 2, 3],
[ 4, 5, 6, 7],
[ 8, 9, 10, 11]])>
'''
#case 2:
tf.gather_nd(data,[0,1])
'''
<tf.Tensor: id=1385, shape=(4,), dtype=int32, numpy=array([4, 5, 6, 7])>
'''
#case 3:
tf.gather_nd(data,[[0,1]])
'''
<tf.Tensor: id=1387, shape=(1, 4), dtype=int32, numpy=array([[4, 5, 6, 7]])>
'''
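The element-level behaviour can likewise be sketched on nested lists. `gather_nd` below is a hypothetical pure-Python helper, not the TensorFlow op; it assumes batch_dims=0, a rectangular indices list, and in-range, non-negative indices.

```python
def gather_nd(params, indices):
    """Mimic tf.gather_nd(params, indices) with batch_dims=0 on nested lists.

    The innermost list of `indices` is one index tuple into params; all
    outer nesting levels of `indices` are kept in the result.
    """
    if indices and isinstance(indices[0], list):
        # Outer level of indices: keep its structure and recurse.
        return [gather_nd(params, sub) for sub in indices]
    # Innermost level: descend one dimension of params per index component.
    out = params
    for i in indices:
        out = out[i]
    return out


# Same values as tf.reshape(tf.range(24), (2, 3, 4)).
data = [[[12 * i + 4 * j + k for k in range(4)] for j in range(3)] for i in range(2)]

print(gather_nd(data, [0, 1]))            # [4, 5, 6, 7]
print(gather_nd(data, [[0, 1]]))          # [[4, 5, 6, 7]]
print(gather_nd(data, [[0, 1], [1, 2]]))  # [[4, 5, 6, 7], [20, 21, 22, 23]]
print(gather_nd(data, [[1, 2, 3]]))       # [23]
```

An index tuple as long as the rank of params (the last line) picks out single elements, which is the typical use of gather_nd for scattered lookups.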
3.tf.boolean_mask
tf.boolean_mask(
tensor, mask, name='boolean_mask', axis=None
)
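Input: tensor, with shape [p1, p2, p3, p4, ..., pn]
mask: a boolean tensor whose shape must equal the dimensions of tensor it covers, starting at axis: mask.shape == tensor.shape[axis:axis+mask.ndim], with mask.ndim >= 1 and axis+mask.ndim <= n (a mismatch raises an error, as in case 2)
axis: the first dimension of tensor that mask is applied to (defaults to 0)
Output: the entries of tensor at the positions where mask is True. The masked dimensions collapse into a single dimension whose size is the number of True values, so the output shape is tensor.shape[:axis] + [num_true] + tensor.shape[axis+mask.ndim:].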
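The shape check and output-shape rule above can be sketched as follows. `boolean_mask_output_shape` is a hypothetical helper assuming a non-negative axis; because the size of the new dimension depends on how many entries of mask are True, the sketch takes that count (`num_true`) as an argument. Its error message imitates the one in case 2 below, but it is produced by this sketch, not by TensorFlow.

```python
def boolean_mask_output_shape(tensor_shape, mask_shape, num_true, axis=0):
    """Shape of tf.boolean_mask(tensor, mask, axis=axis).

    mask must match the dimensions of tensor it covers, starting at `axis`;
    those dimensions collapse into one of size num_true (the count of True):
    tensor.shape[:axis] + (num_true,) + tensor.shape[axis + mask.ndim:]
    """
    tensor_shape, mask_shape = tuple(tensor_shape), tuple(mask_shape)
    if not mask_shape:
        raise ValueError("mask must have rank >= 1")
    end = axis + len(mask_shape)
    covered = tensor_shape[axis:end]
    if covered != mask_shape:
        raise ValueError(f"Shapes {covered} and {mask_shape} are incompatible")
    return tensor_shape[:axis] + (num_true,) + tensor_shape[end:]


# The three cases below, for data of shape (2, 3, 4):
print(boolean_mask_output_shape((2, 3, 4), (2,), num_true=1))    # (1, 3, 4)
print(boolean_mask_output_shape((2, 3, 4), (2, 3), num_true=3))  # (3, 4)
try:
    boolean_mask_output_shape((2, 3, 4), (1,), num_true=1)
except ValueError as err:
    print(err)  # Shapes (2,) and (1,) are incompatible
```

Since num_true is only known once the mask values are known, the size of that one output dimension is data-dependent, unlike the other dimensions.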
data=tf.reshape(tf.range(24),(2,3,4))
'''<tf.Tensor: id=1372, shape=(2, 3, 4), dtype=int32, numpy=
array([[[ 0, 1, 2, 3],
[ 4, 5, 6, 7],
[ 8, 9, 10, 11]],
[[12, 13, 14, 15],
[16, 17, 18, 19],
[20, 21, 22, 23]]])>'''
#case 1:
tf.boolean_mask(data,[True,False],axis=0)
'''
<tf.Tensor: id=1451, shape=(1, 3, 4), dtype=int32, numpy=
array([[[ 0, 1, 2, 3],
[ 4, 5, 6, 7],
[ 8, 9, 10, 11]]])>
'''
#case 2: mask shape does not match the tensor
tf.boolean_mask(data,[True],axis=0)
'''
ValueError: Shapes (2,) and (1,) are incompatible
'''
#case 3:
tf.boolean_mask(data,[[True,False,False],[True,False,True]],axis=0)
'''
<tf.Tensor: id=1482, shape=(3, 4), dtype=int32, numpy=
array([[ 0, 1, 2, 3],
[12, 13, 14, 15],
[20, 21, 22, 23]])>
'''
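A pure-Python sketch of the values returned, on nested lists, assuming a rectangular mask and a non-negative axis. `boolean_mask`, `nested_shape` and `flatten` are hypothetical helpers written for this illustration; the sketch collapses the masked dimensions in row-major order and keeps the slices where mask is True, which reproduces the order seen in case 3.

```python
def nested_shape(x):
    """Shape of a rectangular nested list, e.g. [[True, False]] -> (1, 2)."""
    shape = []
    while isinstance(x, list):
        shape.append(len(x))
        if not x:
            break
        x = x[0]
    return tuple(shape)


def flatten(x, depth):
    """Flatten the first `depth` nesting levels of x (row-major) into one list."""
    if depth == 0:
        return [x]
    return [item for sub in x for item in flatten(sub, depth - 1)]


def boolean_mask(tensor, mask, axis=0):
    """Mimic tf.boolean_mask(tensor, mask, axis=axis) on nested lists."""
    if axis > 0:
        # Dimensions before `axis` are kept: apply the mask inside each slice.
        return [boolean_mask(sub, mask, axis - 1) for sub in tensor]
    mask_shape = nested_shape(mask)
    covered = nested_shape(tensor)[:len(mask_shape)]
    if covered != mask_shape:
        raise ValueError(f"Shapes {covered} and {mask_shape} are incompatible")
    depth = len(mask_shape)
    # Collapse the masked dimensions into one and keep the slices marked True.
    return [s for s, keep in zip(flatten(tensor, depth), flatten(mask, depth)) if keep]


# Same values as tf.reshape(tf.range(24), (2, 3, 4)).
data = [[[12 * i + 4 * j + k for k in range(4)] for j in range(3)] for i in range(2)]

print(boolean_mask(data, [[True, False, False], [True, False, True]]))
# [[0, 1, 2, 3], [12, 13, 14, 15], [20, 21, 22, 23]]
print(boolean_mask(data, [True, False, True], axis=1))
# [[[0, 1, 2, 3], [8, 9, 10, 11]], [[12, 13, 14, 15], [20, 21, 22, 23]]]
```

The second call shows axis at work: the leading dimension of size 2 is left alone, and the 1-D mask filters the second dimension inside each slice.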
Source: https://blog.csdn.net/u014426939/article/details/117511948