Quantcast
Channel: Sam的技术Blog
Viewing all articles
Browse latest Browse all 158

Tensorflow中的shape和rank

$
0
0
作者: Sam (甄峰)   sam_code@hotmail.com

TensorFlow中的数据以张量(Tensor)形式出现,张量(Tensor)可以看成一个多维数组。而shape则代表张量的形状。rank则表示Tensor的维度

1. tf.shape()讲解:
tf.shape(
    input
,
    name
=None,
    out_type
=tf.dtypes.int32
)

这个函数获取参数一(input)这个张量的形状(shape),并以整形一维array的形式提供出来。

这个形状比较特别,需要仔细查看, 以下为例子:
import tensorflow as tf

print("SamInfo")

c1 = tf.constant(value=[[[2., 4., 5.,4.],[3.,5., 5.,4.],[3.,5., 5.,4.]], [[2., 4., 5.,4.],[3.,5., 5.,4.], [3.,5., 5.,4.]]], dtype=tf.float32)
c2 = tf.constant([[2], [2]] ) 
c3 = tf.constant([1,2,3,4])


with tf.Session() as sess:
    print (sess.run(tf.shape(c1)))
    print (sess.run(tf.shape(c2)))
    print(sess.run(tf.shape(c3)))


输出为:

SamInfo
[2 3 4]
[2 1]
[4]



解释如下:

c3 = tf.constant([1,2,3,4])

c3这个常量是:[1,2,3,4]组成。在[]最外层,有4个单元,则shape为[4]



c2 = tf.constant([[2], [2]] ) 

c2这个常亮,最外层[]内有两个单元,而每个单元内,又各自有一个单元。所以shape为[2 1]



c1 = tf.constant(value=[[[2., 4., 5.,4.],[3.,5., 5.,4.],[3.,5., 5.,4.]], [[2., 4., 5.,4.],[3.,5., 5.,4.], [3.,5., 5.,4.]]], dtype=tf.float32)


c3这个常亮,最外层[]内有两个单元,而每个单元内,又各自有3个单元。最内层单元内,则有4个单元。所以shape为[2 3 4]





2. shape的使用:

2.1. 直接输入1-D array.

import tensorflow as tf



v2 = tf.random_uniform([3,3,2], minval= -1, maxval=1)

init = tf.global_variables_initializer()

with tf.Session() as sess: sess.run(init) print("Tensor is:", sess.run(v2))

将出现3个单元,每个单元又3个单元,最小的单元包含2个数字。

Tensor is: [[[-0.84239459 -0.64019585]
  [-0.20693684  0.61758161]
  [-0.39858079 -0.55873823]]

 [[ 0.63348198  0.6409862 ]
  [ 0.59288216 -0.47363114]
  [-0.80493784  0.34760666]]

 [[-0.63808179 -0.62189484]
  [-0.97743845  0.87292027]
  [-0.83223104  0.27168941]]]



2.2. 直接仿照其它Tensor的shape.

v2 = tf.random_uniform([3,3,2], minval= -1, maxval=1)
v3 = tf.zeros(tf.shape(v2))



3. Tensor的rank:

Tensor类似一个多维数组,那它的维度,就有rank表示。



c1 = tf.constant(value=[[[2., 4., 5.,4.],[3.,5., 5.,4.],[3.,5., 5.,4.]], [[2., 4., 5.,4.],[3.,5., 5.,4.], [3.,5., 5.,4.]]], dtype=tf.float32)



print ("C1 is:",sess.run(tf.shape(c1)), "rank is:", sess.run(tf.rank(c1)))




C1 is: [2 3 4] rank is: 3
表明它是个3-D Tensor。







 

Viewing all articles
Browse latest Browse all 158

Trending Articles