Python

python中的.shape(), .shape 和 ten

2019-07-18  本文已影响0人  西北小生_

Python和tensorflow编程中经常见这三种shape的用法,容易混淆,特写一篇文章来总结以备遗忘。这三个函数都是用来获取维度信息的,但用法和使用对象各有不同,下面进行一一介绍。

(1) np.shape()

这个函数是numpy中的一个函数(函数要加括号!!!),其功能是获取括号内数据的长度或维度信息,其使用对象既可以是一个数,也可以是数组或矩阵。如下例所示:

In [1]: import numpy as np                                                      

In [2]: np.shape(0)                                                             
Out[2]: ()

In [3]: np.shape([0])                                                           
Out[3]: (1,)

In [4]: np.shape([1, 2, 3])                                                     
Out[4]: (3,)

In [5]: np.shape([[1], [2]])                                                    
Out[5]: (2, 1)

In [12]: a = np.zeros([2,3])                                                    

In [13]: a                                                                      
Out[13]: 
array([[0., 0., 0.],
       [0., 0., 0.]])

In [14]: np.shape(a)                                                            
Out[14]: (2, 3)

In [15]: np.shape(a)[1]                                                         
Out[15]: 3

(2) array.shape

array.shape是numpy中ndarray数据类型的一个属性。我们先来理解一下几个问题:

1.什么是ndarray数据类型?

ndarray是numpy库中的一种数据类型,凡是以np.array()定义的数据都是ndarray类型,就跟pytorch中的张量tensor类似。

2.什么是属性?

属性就是python类中初始化的时候,self.xx代表的变量,是该类特有的信息。比如我们定义一个学生类:

class Student:
        def __init__(self, height, weight, number):
                self.height = height  # 身高
                self.weight = weight  # 体重
                self.number = number  # 学号

Student类中的self.height,self.weight ,self.number就是属性。如果ZhangSan是一个Student类,我们想要获知张三的身高体重学号等信息,就采用ZhangSan.height,ZhanSan.weight,ZhangSan.number即可获得,并且可以看到这里的属性是不带括号的
在Python中,一切数据对象都是一个类,包括ndarray类型。shape就是ndarray数据的一个属性,shape表示这个ndarray实例的形状,即各维度的数值。dtype也是其属性之一,即datatype得缩写,表示这个ndarray实例的数据类型。
需要注意的就是注意属性不加括号!!!使用方法如下:

In [16]: b = np.array([[1,2,3],[4,5,6],[7,8,9]])                                

In [17]: b                                                                      
Out[17]: 
array([[1, 2, 3],
       [4, 5, 6],
       [7, 8, 9]])

In [18]: b.shape                                                                
Out[18]: (3, 3)

In [19]: b.shape[0]                                                             
Out[19]: 3

In [20]: c = [1, 2, 3]        # c不是ndarray类型                                                  

In [21]: c.shape                                                                
---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
<ipython-input-21-d6049491b182> in <module>
----> 1 c.shape

AttributeError: 'list' object has no attribute 'shape'

从类的角度看,把ndarray和numpy都当做Python的一个类,ndarray.shape表示ndarray的属性,自然可知,np.shape()其实就是numpy类的方法。
在numpy中,一般可直接用于ndarray类型数据上的方法也有与之对应的numpy函数可执行相同操作,如:

In [52]: a = np.arange(5)                                                       

In [53]: a                                                                      
Out[53]: array([0, 1, 2, 3, 4])

In [54]: np.sum(a)                                                              
Out[54]: 10

In [55]: a.sum()                                                                
Out[55]: 10

记住,函数或方法很像,都要带括号!!!属性不带括号!!!

In [56]: a = np.random.randn(5,3)                                               

In [57]: a                                                                      
Out[57]: 
array([[-0.47169257, -1.33625595,  1.09450799],
       [ 0.68097098, -0.77349608, -0.13462524],
       [ 1.01122524, -0.72573122, -2.80145914],
       [ 0.32187105,  0.66012558, -0.80316889],
       [-0.79434656,  0.33565231, -0.51083857]])

In [58]: a.shape              #获取矩阵大小                                                  
Out[58]: (5, 3)

In [59]: a.ndim                #获取矩阵维度                                                
Out[59]: 2

In [60]: a.dtype               #获取矩阵数据类型                                                 
Out[60]: dtype('float64')

(3) Tensor.get_shape().as_list()

这是tensorflow中常用于获取tensor维度信息的函数,注意该函数只能用于tensor对象。Tensor.get_shape()本身获取tensor的维度信息并以元组的形式返回,由于元组内容不可更改,故该函数常常跟.as_list()连用,返回一个tensor维度信息的列表,以供后续操作使用。

上一篇下一篇

猜你喜欢

热点阅读