
tf.gather, tf.gather_nd, tf.boolean_mask

Author: (reposted from the Internet)

Function documentation links:

tf.gather: https://tensorflow.google.cn/versions/r1.15/api_docs/python/tf/gather

tf.gather_nd: https://tensorflow.google.cn/versions/r1.15/api_docs/python/tf/gather_nd

tf.boolean_mask: https://tensorflow.google.cn/versions/r1.15/api_docs/python/tf/boolean_mask

Differences

1. tf.gather

tf.gather(
    params, indices, validate_indices=None, name=None, axis=None, batch_dims=0
)

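Input: params, shape [p1, p2, p3, p4, ...]
        indices, shape [i1, i2, ...]
        axis: the dimension of params to gather along
Output: the slices of params along dimension axis whose positions are given by the values in indices. The output shape is params.shape[:axis] + indices.shape + params.shape[axis+1:], i.e. dimension axis of params is replaced by the full shape of indices.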
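The output-shape rule above can be checked with plain arithmetic on shape tuples. The sketch below is a hypothetical pure-Python helper (`gather_output_shape` is not a TensorFlow function) that assumes batch_dims=0 and a non-negative axis; it reproduces the shapes printed in the three cases that follow.

```python
def gather_output_shape(params_shape, indices_shape, axis=0):
    """Shape of tf.gather(params, indices, axis=axis) when batch_dims=0.

    Hypothetical helper: dimension `axis` of params is replaced by the
    full shape of indices; all other dimensions are kept.
    """
    params_shape = tuple(params_shape)
    indices_shape = tuple(indices_shape)
    if not 0 <= axis < len(params_shape):
        raise ValueError("axis out of range for params")
    return params_shape[:axis] + indices_shape + params_shape[axis + 1:]


# The three tf.gather cases on data of shape (2, 3, 4):
print(gather_output_shape((2, 3, 4), (2,), axis=0))    # case 1: (2, 3, 4)
print(gather_output_shape((2, 3, 4), (2,), axis=1))    # case 2: (2, 2, 4)
print(gather_output_shape((2, 3, 4), (1, 2), axis=1))  # case 3: (2, 1, 2, 4)
```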
import tensorflow as tf  # eager execution assumed, since the outputs below print numpy= values
data=tf.reshape(tf.range(24),(2,3,4))
'''<tf.Tensor: id=1372, shape=(2, 3, 4), dtype=int32, numpy=
array([[[ 0,  1,  2,  3],
        [ 4,  5,  6,  7],
        [ 8,  9, 10, 11]],
       [[12, 13, 14, 15],
        [16, 17, 18, 19],
        [20, 21, 22, 23]]])>'''

#case 1:
tf.gather(data,[0,1],axis=0)
'''
<tf.Tensor: id=1375, shape=(2, 3, 4), dtype=int32, numpy=
array([[[ 0,  1,  2,  3],
        [ 4,  5,  6,  7],
        [ 8,  9, 10, 11]],
       [[12, 13, 14, 15],
        [16, 17, 18, 19],
        [20, 21, 22, 23]]])>
'''

#case 2:
tf.gather(data,[0,1],axis=1)
'''
<tf.Tensor: id=1378, shape=(2, 2, 4), dtype=int32, numpy=
array([[[ 0,  1,  2,  3],
        [ 4,  5,  6,  7]],
       [[12, 13, 14, 15],
        [16, 17, 18, 19]]])>
'''

#case 3:
tf.gather(data,[[0,1]],axis=1)
'''
<tf.Tensor: id=1381, shape=(2, 1, 2, 4), dtype=int32, numpy=
array([[[[ 0,  1,  2,  3],
         [ 4,  5,  6,  7]]],
       [[[12, 13, 14, 15],
         [16, 17, 18, 19]]]])>
'''
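To see where each value in cases 2 and 3 comes from, the lookup can be modeled on nested Python lists. This is a minimal sketch, not TensorFlow's implementation: `gather_lists` is a hypothetical helper that assumes batch_dims=0, a non-negative axis, and in-range, non-negative indices (they are not validated here).

```python
def gather_lists(params, indices, axis=0):
    """Pure-Python model of tf.gather on nested lists (hypothetical helper).

    Assumes batch_dims=0, a non-negative axis and valid indices.
    """
    if axis > 0:
        # Dimensions before `axis` are kept: recurse into each sub-list.
        return [gather_lists(p, indices, axis - 1) for p in params]
    if isinstance(indices, list):
        # Each level of nesting in indices becomes a level in the output.
        return [gather_lists(params, i, 0) for i in indices]
    return params[indices]  # a scalar index picks one slice along this axis


# Same values as data = tf.reshape(tf.range(24), (2, 3, 4))
data = [[[12 * a + 4 * b + c for c in range(4)] for b in range(3)] for a in range(2)]

print(gather_lists(data, [0, 1], axis=1))
# case 2: [[[0, 1, 2, 3], [4, 5, 6, 7]], [[12, 13, 14, 15], [16, 17, 18, 19]]]
print(gather_lists(data, [[0, 1]], axis=1))
# case 3: [[[[0, 1, 2, 3], [4, 5, 6, 7]]], [[[12, 13, 14, 15], [16, 17, 18, 19]]]]
```

The recursion mirrors the shape rule: levels before axis are copied, the axis level is replaced by the nesting of indices, and the remaining dimensions come along unchanged inside each selected slice.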

2. tf.gather_nd

tf.gather_nd(
    params, indices, name=None, batch_dims=0
)
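Input: params, shape [p1, p2, p3, p4, ...]
        indices, shape [i1, i2, ..., ik]; each vector along the last dimension (length ik) is an index tuple into the first ik dimensions of params
Output: each index tuple selects one element or slice of params, and the results are assembled into a tensor of shape indices.shape[:-1] + params.shape[indices.shape[-1]:]. Unlike tf.gather, which indexes a single axis, one tuple can index several leading dimensions at once.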
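The same kind of shape arithmetic applies to tf.gather_nd: the length of the last dimension of indices decides how many leading dimensions of params are consumed. `gather_nd_output_shape` below is a hypothetical helper for batch_dims=0, checked against the three cases that follow.

```python
def gather_nd_output_shape(params_shape, indices_shape):
    """Shape of tf.gather_nd(params, indices) when batch_dims=0.

    Hypothetical helper: the last dimension of indices says how many
    leading dimensions of params each index tuple consumes.
    """
    params_shape = tuple(params_shape)
    indices_shape = tuple(indices_shape)
    depth = indices_shape[-1]  # length of each index tuple
    if depth > len(params_shape):
        raise ValueError("index depth exceeds the rank of params")
    return indices_shape[:-1] + params_shape[depth:]


print(gather_nd_output_shape((2, 3, 4), (1,)))    # case 1: (3, 4)
print(gather_nd_output_shape((2, 3, 4), (2,)))    # case 2: (4,)
print(gather_nd_output_shape((2, 3, 4), (1, 2)))  # case 3: (1, 4)
```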
data=tf.reshape(tf.range(24),(2,3,4))
'''<tf.Tensor: id=1372, shape=(2, 3, 4), dtype=int32, numpy=
array([[[ 0,  1,  2,  3],
        [ 4,  5,  6,  7],
        [ 8,  9, 10, 11]],
       [[12, 13, 14, 15],
        [16, 17, 18, 19],
        [20, 21, 22, 23]]])>'''

#case 1:
tf.gather_nd(data,[0])
'''
<tf.Tensor: id=1383, shape=(3, 4), dtype=int32, numpy=
array([[ 0,  1,  2,  3],
       [ 4,  5,  6,  7],
       [ 8,  9, 10, 11]])>
'''

#case 2:
tf.gather_nd(data,[0,1])
'''
<tf.Tensor: id=1385, shape=(4,), dtype=int32, numpy=array([4, 5, 6, 7])>
'''

#case 3:
tf.gather_nd(data,[[0,1]])
'''
<tf.Tensor: id=1387, shape=(1, 4), dtype=int32, numpy=array([[4, 5, 6, 7]])>
'''
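A nested-list model makes the contrast with tf.gather concrete: tf.gather(data, [0, 1]) returns two slices along axis 0, whereas tf.gather_nd reads [0, 1] as a single coordinate, data[0][1]. `gather_nd_lists` is a hypothetical pure-Python sketch (batch_dims=0, valid indices and non-empty lists assumed), not TensorFlow's implementation; the last call is an extra illustration with two index tuples.

```python
def gather_nd_lists(params, indices):
    """Pure-Python model of tf.gather_nd on nested lists (hypothetical helper)."""
    if isinstance(indices[0], list):
        # Not yet at the last dimension: keep the outer structure of indices.
        return [gather_nd_lists(params, sub) for sub in indices]
    # indices is now one index tuple: walk down one dimension per entry.
    result = params
    for i in indices:
        result = result[i]
    return result


# Same values as data = tf.reshape(tf.range(24), (2, 3, 4))
data = [[[12 * a + 4 * b + c for c in range(4)] for b in range(3)] for a in range(2)]

print(gather_nd_lists(data, [0]))       # case 1: [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]
print(gather_nd_lists(data, [0, 1]))    # case 2: [4, 5, 6, 7]
print(gather_nd_lists(data, [[0, 1]]))  # case 3: [[4, 5, 6, 7]]
print(gather_nd_lists(data, [[0, 1], [1, 2]]))  # [[4, 5, 6, 7], [20, 21, 22, 23]]
```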

3. tf.boolean_mask

tf.boolean_mask(
    tensor, mask, name='boolean_mask', axis=None
)
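Input: tensor, shape [p1, p2, p3, p4, ..., pn]
        mask: boolean tensor of rank k >= 1 whose shape must equal the k dimensions of tensor starting at axis, i.e. mask.shape == tensor.shape[axis:axis+k] (a mismatch raises an error; see case 2)
        axis: the first dimension of tensor that mask applies to (None means 0)

Output: the slices of tensor at the positions where mask is True, stacked into a tensor of shape tensor.shape[:axis] + [number of True values in mask] + tensor.shape[axis+k:]. The k masked dimensions collapse into a single dimension.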
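The shape check that makes case 2 fail, and the output shape in the other cases, can be sketched as follows. `boolean_mask_output_shape` and `_shape_and_count` are hypothetical pure-Python helpers that take the mask as nested lists of booleans and require an exact shape match; the error message is modeled on the one shown in case 2.

```python
def _shape_and_count(mask):
    """Return (shape, number of True entries) for a nested-list boolean mask."""
    if not isinstance(mask, list):
        return (), int(bool(mask))
    parts = [_shape_and_count(m) for m in mask]
    inner_shape = parts[0][0] if parts else ()
    return (len(mask),) + inner_shape, sum(count for _, count in parts)


def boolean_mask_output_shape(tensor_shape, mask, axis=0):
    """Output shape of tf.boolean_mask(tensor, mask, axis=axis) (hypothetical helper)."""
    tensor_shape = tuple(tensor_shape)
    mask_shape, num_true = _shape_and_count(mask)
    k = len(mask_shape)
    if k == 0:
        raise ValueError("mask must have at least one dimension")
    covered = tensor_shape[axis:axis + k]  # the dimensions the mask must match
    if covered != mask_shape:
        raise ValueError(f"Shapes {covered} and {mask_shape} are incompatible")
    # The k masked dimensions collapse into one dimension of size num_true.
    return tensor_shape[:axis] + (num_true,) + tensor_shape[axis + k:]


print(boolean_mask_output_shape((2, 3, 4), [True, False]))  # case 1: (1, 3, 4)
print(boolean_mask_output_shape((2, 3, 4), [[True, False, False], [True, False, True]]))  # case 3: (3, 4)
try:
    boolean_mask_output_shape((2, 3, 4), [True])  # case 2
except ValueError as err:
    print("ValueError:", err)  # ValueError: Shapes (2,) and (1,) are incompatible
```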
data=tf.reshape(tf.range(24),(2,3,4))
'''<tf.Tensor: id=1372, shape=(2, 3, 4), dtype=int32, numpy=
array([[[ 0,  1,  2,  3],
        [ 4,  5,  6,  7],
        [ 8,  9, 10, 11]],
       [[12, 13, 14, 15],
        [16, 17, 18, 19],
        [20, 21, 22, 23]]])>'''

#case 1:
tf.boolean_mask(data,[True,False],axis=0)
'''
<tf.Tensor: id=1451, shape=(1, 3, 4), dtype=int32, numpy=
array([[[ 0,  1,  2,  3],
        [ 4,  5,  6,  7],
        [ 8,  9, 10, 11]]])>
'''

#case 2: mask shape does not match the tensor
tf.boolean_mask(data,[True],axis=0)
'''
ValueError: Shapes (2,) and (1,) are incompatible
'''

#case 3:
tf.boolean_mask(data,[[True,False,False],[True,False,True]],axis=0)
'''
<tf.Tensor: id=1482, shape=(3, 4), dtype=int32, numpy=
array([[ 0,  1,  2,  3],
       [12, 13, 14, 15],
       [20, 21, 22, 23]])>
'''
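Finally, a nested-list model of the selection itself. `boolean_mask_lists` and `_collect` are hypothetical pure-Python helpers, not TensorFlow code: they walk tensor and mask together and append every slice whose mask entry is True, which is why a k-dimensional mask collapses k dimensions into one. The axis=1 call is an extra illustration beyond the cases above.

```python
def boolean_mask_lists(tensor, mask, axis=0):
    """Pure-Python model of tf.boolean_mask on nested lists (hypothetical helper)."""
    if axis > 0:
        # Dimensions before `axis` are kept: apply the same mask inside each sub-list.
        return [boolean_mask_lists(t, mask, axis - 1) for t in tensor]
    selected = []
    _collect(tensor, mask, selected)
    return selected  # the masked dimensions are flattened into this one list


def _collect(tensor, mask, out):
    """Walk tensor and mask in lockstep, appending slices where mask is True."""
    if len(tensor) != len(mask):
        raise ValueError(f"mask length {len(mask)} does not match dimension {len(tensor)}")
    for t, m in zip(tensor, mask):
        if isinstance(m, list):
            _collect(t, m, out)  # descend into the next masked dimension
        elif m:
            out.append(t)


# Same values as data = tf.reshape(tf.range(24), (2, 3, 4))
data = [[[12 * a + 4 * b + c for c in range(4)] for b in range(3)] for a in range(2)]

print(boolean_mask_lists(data, [True, False]))
# case 1: [[[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]]
print(boolean_mask_lists(data, [[True, False, False], [True, False, True]]))
# case 3: [[0, 1, 2, 3], [12, 13, 14, 15], [20, 21, 22, 23]]
print(boolean_mask_lists(data, [True, False, True], axis=1))
# axis=1: [[[0, 1, 2, 3], [8, 9, 10, 11]], [[12, 13, 14, 15], [20, 21, 22, 23]]]
```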


Source: https://blog.csdn.net/u014426939/article/details/117511948