作者: Sam (甄峰)
sam_code@hotmail.com
x_data_2 =
[[1,2,3],[4,5,6],[7,8,9],[10,11,12],[13,14,15],[16,17,18]]
sp = tf.split(x_data_2, 2, axis=0)
with tf.Session() as sess:
for c in sp:
print("split:",sess.run(c))
split: [[1 2 3] [4 5 6] [7 8 9]] split: [[10 11 12] [13 14 15] [16 17 18]]
x_data_2 = [[1,2,3],[4,5,6],[7,8,9],[10,11,12],[13,14,15],[16,17,18]]
sp = tf.split(x_data_2, [1,2], axis=1)
with tf.Session() as sess:
for c in sp:
print("split:",sess.run(c))
split: [[ 1] [ 4] [ 7] [10] [13] [16]] split: [[ 2 3] [ 5 6] [ 8 9] [11 12] [14 15] [17 18]]
常见的用法:
例如:有数据为NHWC--[batch, height, width, channels]
如果想把各个chanels的数据单独拿出来,则可以使用tf.split()