其他分享
首页 > 其他分享> > tf.strided_slice_and_tf.fill_and_tf.concat

tf.strided_slice_and_tf.fill_and_tf.concat

作者:互联网

tf.strided_slice,tf.fill,tf.concat使用实例

 其中,我们需要对tensor data进行切片,tf.strided_slice使用方法请参考

import tensorflow as tf

import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

# process_decoder_input
data = tf.constant(
	[
		[4, 5, 20, 20, 22, 3], [17, 19, 28, 8, 7, 3], [5, 13, 15, 24, 26, 3], [5, 20, 25, 4, 5, 3],
		[4, 12, 14, 15, 5, 3], [4, 7, 7, 16, 23, 3], [7, 8, 10, 13, 19, 3]
	])

batch_size = 6
ending = tf.strided_slice(data, [0, 0], [6, -1], [1, 1])
fill = tf.fill([6, 1], 2)
decoder_input = tf.concat([tf.fill([batch_size, 1], 2), ending], 1)


# Decoder
# 先对target数据进行预处理
def process_decoder_input(data, vocab_to_int, batch_size):
	"""
	补充<GO>,并移除最后一个字符
	"""
	# cut掉最后一个字符
	ending = tf.strided_slice(data, [0, 0], [batch_size, -1], [1, 1])
	fill = tf.fill([batch_size, 1], vocab_to_int['<GO>'])

	# vocab_to_int['<GO>']在本例中是2,经过在列维度上的合并,每个序列都是以GO(对应数值为2)开头
	decoder_input = tf.concat([fill, ending], 1)

	return ending, fill, decoder_input


data = tf.constant(
	[
		[4, 5, 20, 20, 22, 3],
		[17, 19, 28, 8, 7, 3],
		[5, 13, 15, 24, 26, 3],
		[5, 20, 25, 4, 5, 3],
		[4, 12, 14, 15, 5, 3],
		[4, 7, 7, 16, 23, 3],
		[7, 8, 10, 13, 19, 3]
	]
)

target_letter_to_int = {
	'<PAD>': 0, '<UNK>': 1, '<GO>': 2, '<EOS>': 3,
	'a': 4, 'b': 5, 'c': 6, 'd': 7, 'e': 8, 'f': 9, 'g': 10, 'h': 11, 'i': 12, 'j': 13, 'k': 14, 'l': 15, 'm': 16,
	'n': 17, 'o': 18, 'p': 19, 'q': 20, 'r': 21, 's': 22, 't': 23, 'u': 24, 'v': 25, 'w': 26, 'x': 27, 'y': 28, 'z': 29}
batch_size = 6

ending, fill, decoder_input = process_decoder_input(data, target_letter_to_int, batch_size)

with tf.Session() as sess:  # 初始化会话
	sess.run(tf.global_variables_initializer())
	print('ending:\n', sess.run(ending))
	print('fill:\n', sess.run(fill))
	print('decoder_input:\n', sess.run(decoder_input))

  结果如下:

'''
ending:
 [[ 4  5 20 20 22]
 [17 19 28  8  7]
 [ 5 13 15 24 26]
 [ 5 20 25  4  5]
 [ 4 12 14 15  5]
 [ 4  7  7 16 23]]
fill:
 [[2]
 [2]
 [2]
 [2]
 [2]
 [2]]
decoder_input:
 [[ 2  4  5 20 20 22]
 [ 2 17 19 28  8  7]
 [ 2  5 13 15 24 26]
 [ 2  5 20 25  4  5]
 [ 2  4 12 14 15  5]
 [ 2  4  7  7 16 23]]
'''

  

标签:slice,20,ending,decoder,tf,input,concat,fill
来源: https://www.cnblogs.com/always-fight/p/12571247.html