Matplotlib使用介绍
Matplotlib 最重要的特性之一就是具有良好的操作系统兼容性和图形显示底层接口兼容性。虽然,近几年Matplotlib 的界面与风格似乎有点跟不上时代,但仍然是数据可视化技术不可缺少的一环。
1. 最常用的基本知识
类似numpy和pandas,matplotlib也有自己约定俗成的引入方式:
In [20]: import matplotlib as mpl
In [22]: import matplotlib.pyplot as plt
通过plt.style来选择图形的绘图风格。
In [23]: plt.style.use('classic')
Matplotlib包含有3种实验环境:分别是脚本、IPython shell 和 IPython Notebook。如果使用脚本,为了显示图片就必须使用plt.show()语句,plt.show() 会启动一个事件循环(eventloop),并找到所有当前可用的图形对象,然后打开一个或多个交互式窗口显示图形。需要注意,一个 Python 会话(session)中只能使用一次 plt.show(),因此通常都把它放在脚本的最后。
在IPython中启动 Matplotlib 模式就可以使用,使用语句%matplotlib,此后的任何 plt 命令都会自动打开一个图形窗口,增加新的命令,图形就会更新(图形是交互式变化的)。对于这些变化,可以使用 plt.draw() 强制更新。不需要使用plt.show()了。
使用IPython Notebook与IPython类似:
%matplotlib notebook 会在 Notebook 中启动交互式图形。
%matplotlib inline 会在 Notebook 中启动静态图形。(经常使用)
In [3]: %matplotlib
Using matplotlib backend: MacOSX
In [4]: x = np.linspace(0,10,100)
In [6]: fig = plt.figure()
In [7]: plt.plot(x,np.sin(x),'-')
Out[7]: [<matplotlib.lines.Line2D at 0x12371a490>]
In [8]: plt.plot(x,np.cos(x),'--')
Out[8]: [<matplotlib.lines.Line2D at 0x10f82e110>]
In [16]: fig.savefig('first_pic.png') #保存为PNG格式的文件
In [17]: fig.canvas.get_supported_filetypes()
Out[17]:
{u'eps': u'Encapsulated Postscript',
u'jpeg': u'Joint Photographic Experts Group',
u'jpg': u'Joint Photographic Experts Group',
u'pdf': u'Portable Document Format',
u'pgf': u'PGF code for LaTeX',
u'png': u'Portable Network Graphics',
u'ps': u'Postscript',
u'raw': u'Raw RGBA bitmap',
u'rgba': u'Raw RGBA bitmap',
u'svg': u'Scalable Vector Graphics',
u'svgz': u'Scalable Vector Graphics',
u'tif': u'Tagged Image File Format',
u'tiff': u'Tagged Image File Format'}
得到下边的图形:
Matplotlib 的一个优点是能够将图形保存为各种不同的数据格式。你可以用 savefig() 命令将图形保存为文件。不同格式的文件是以文件后缀名来区分的,上边的命令列出了matplotlib支持的所有格式名。
2. 两种画图接口
Matplotlib 有一个容易让人混淆的特性,就是它的两种画图接口:
(1) 便捷的 MATLAB 风格接口。
(2)功能更强大的面向对象接口。
2.1. MATLAB风格
MATLAB 风格的工具位于 pyplot(plt)接口中,如下边一个例子:
x = np.linspace(0,10,100)
fig = plt.figure()
plt.subplot(2,1,1) # 行、列、编号 (图形矩阵)
plt.plot(x,np.sin(x),'-')
plt.subplot(2,1,2)
plt.plot(x,np.cos(x),'--')
这种接口最重要的特性是有状态的(stateful)。可以用plt.gcf()(获取当前图形)和 plt.gca()(获取当前坐标轴)来查看。 这种图画起图来又快又方便,但是也很容易出问题。例如,在画第2个图时,已经无法修改第1个图了。
2.2. 面向对象风格
面向对象接口可以处理更加复杂的场景,画图函数不再受到当前“活动”图形或坐标轴的限制,而变成了显式的 Figure 和 Axes 的方法。
fig,ax = plt.subplots(2)
ax[0].plot(x,np.sin(x))
ax[1].plot(x,np.cos(x))
3. 线形图
要画 Matplotlib 图形时,都需要先创建一个图形 fig 和一个坐标轴ax。
最简单地,通过下边方法创建这2个对象:
fig = plt.figure()
ax = plt.axes()
在 Matplotlib 里面,figure(plt.Figure 类的一个实例)可以被看成是一个能够容纳各种坐标轴、图形、文字和标签的容器。axes(plt.Axes 类的一个实例)是一个带有刻度和标签的矩形,最终会包含所有可视化的图形元素。
fig = plt.figure()
ax = plt.axes()
ax.plot(x,np.sin(x),color='blue',linestyle = '--',label='sin(x)')
ax.plot(x,np.sin(x-1),color='olive',linestyle ='-.',label='sin(x-1)')
ax.plot(x,np.sin(x-2),'-g',label='sin(x-2)') #线条风格+颜色简写形式
ax.set(xlim=(-2,11),ylim=(-2,2),xlabel='x',ylabel='y',title='A simple plot')
ax.legend() # 显示图例
image.png
如果需要在一个fig画多个线,直接plot多次即可。ax对象可以通过set()方法来统一设置很多属性。
matplotlib 可选的线条风格和颜色:
https://blog.csdn.net/detaswc/article/details/81086757
最后,通过一张表格来总结下本节介绍过的一些常用的属性(更详细可以参考官方文档):
name | 说明 | 应用层级 |
---|---|---|
color | 配置线条颜色 | 曲线 |
linestyle | 线条风格(直线、虚线等) | 曲线 |
label | 曲线标签(图例) | 曲线 |
xlim | 图形x轴取值范围 | 整个图表 |
ylim | 图形y轴取值范围 | 整个图表 |
xlabel | 图形x轴标签名 | 整个图表 |
ylabel | 图形y轴标签名 | 整个图表 |
title | 图形的标题 | 整个图表 |
legend() | 生成图例的方法 | 方法 |
4. 散点图
散点图(scatter plot)与线形图类似,不再由线段连接,而是由独立的点、圆圈或其他形状构成。
4.1 使用plt.plot画散点图
第一种方式与上边线形图的构造方法十分类似,只修改下linestyle参数即可(之前是- 修改为o),执行效果如下:
plt.plot(x,np.sin(x),'o',color='black')
4.2 使用plt.scatter画散点图
plt.scatter 与 plt.plot 的主要差别在于,前者在创建散点图时具有更高的灵活性,可以单独控制每个散点与数据匹配,也可以让每个散点具有不同的属性(大小、表面颜色、边框颜色等)。
当数据变大到几千个散点时,plt.plot 的效率将大大高于plt.scatter。因此面对大型数据集时,plt.plot 方法比 plt.scatter 方法好。
下面以sklearn的鸢尾花数据为例制作一个散点图:
['sepal length (cm)',
'sepal width (cm)',
'petal length (cm)',
'petal width (cm)']
了解特征数据的说明:
from sklearn.datasets import load_iris
iris = load_iris()
iris.keys()
iris.DESCR
iris.feature_names
作图代码:
from sklearn.datasets import load_iris
iris = load_iris()
features = iris.data.T
feature_names = iris.feature_names
fig = plt.plot()
ax = plt.axes()
ax.scatter(features[0],features[1],alpha=0.2,s=100*features[3], c=iris.target, cmap='viridis')
ax.set(xlabel=feature_names[0] , ylabel=feature_names[1],title='First Scatter Plot')
Scatter参数详见:https://matplotlib.org/api/_as_gen/matplotlib.axes.Axes.scatter.html#matplotlib.axes.Axes.scatter
5. 误差线图
基本误差线(errorbar)可以通过一个 Matplotlib 函数来创建。
其中yerr指定是在纵轴方向误差,xerr是在横轴方向误差,当然可以两方向同时误差。
下面是个简单的例子:
fig = plt.figure()
ax = plt.axes()
dy = 0.8
x_e = np.linspace(0,10,50)
y_e = np.sin(x_e) + dy * np.random.randn(50)
ax.errorbar(x_e,y_e,yerr=dy,fmt='o',color='black',ecolor='lightgray',
elinewidth=3,capsize=0)
errorbar参数详见:https://matplotlib.org/api/_as_gen/matplotlib.axes.Axes.errorbar.html#matplotlib.axes.Axes.errorbar
6. 密度图与等高线图
在二维图上用等高线图或者彩色图来表示三维数据是个不错的方法。用 plt.contour 画等高线图、用 plt.contourf 画带有填充色的等高线图(filled contourplot)的色彩、用 plt.imshow 显示图形。
等高线图可以用 plt.contour 函数来创建。它需要三个参数:x 轴、y轴、z 轴三个坐标轴的网格数据。x 轴与 y 轴表示图形中的位置,而 z 轴将通过等高线的等级来表示。
np.meshgrid 函数可以从一维数组构建二维网格数据。
理解np.meshgrid :
实际上就是 生成网格点坐标矩阵
我以一个简单的例子展开说明:
t_ga = np.array([0,4,5,8,9])
t_gb = np.array([1,3,9,12])
XT,YT = np.meshgrid(t_ga,t_gb)
plt.plot(XT,YT,color='green',linestyle='',marker='o')
plt.show()
以上代码得到图形如下:
实际上,以提供的两个数组,在X轴和Y轴做垂线,这些所有线的交点也成了所谓的网格矩阵了。
参考文章:
https://blog.csdn.net/lllxxq141592654/article/details/81532855
介绍完meshgrid概念后,看看最简单的等高图:
def f(x, y):
return np.sin(x) ** 10 + np.cos(10 + y * x) * np.cos(x)
x_f = np.linspace(0,5,50)
y_f = np.linspace(0,5,40)
X,Y = np.meshgrid(x_f,y_f)
Z = f(X,Y)
fig = plt.figure()
ax = plt.axes()
ax.contour(X,Y,Z,colors='black')
当图形中只使用一种颜色时,默认使用虚线表示负数,使用实线表示正数。此图确实比较难看,下边提供了一系列改进方式。
改进1:
你可以用 cmap 参数设置一个线条配色方案来自定义颜色。下边语句将数据范围等分为 20 份,然后用不同的颜色表示。Matplotlib 有非常丰富的配色方案,你可以在IPython 里用 Tab 键浏览 plt.cm 模块对应的信息。plt.cm.<TAB>
ax.contour(X,Y,Z,30,cmap='RdGy')
改进2:
上边图形的线条间的间隙很大,可以通过plt.contourf来改进,其语法与contour完全一致。另外,通过 plt.colorbar() 命令自动创建一个表示图形各种颜色对应标签信息的颜色条。
ax0 = ax.contourf(X,Y,Z,30,cmap='RdGy')
fig.colorbar(ax0,ax=ax)
改进3:
上边的图形的主要问题是颜色的改变是一个离散而非连续的过程,这导致了图形看起来有点儿“污渍斑斑”。plt.imshow() 函数可以将二维数组渲染成渐变图。
contours = plt.contour(X, Y, Z, 3, colors='black')
plt.clabel(contours, inline=True, fontsize=8)
plt.imshow(Z, extent=[0, 5, 0, 5], origin='lower',
cmap='RdGy', alpha=0.5)
plt.colorbar();
7. 频次直方图、数据区间划分和分布密度
10.多子图
当需要做对比分析时,有必要画多个图。
10.1 plt.axes()创建子图
plt.axes()函数有4个参数:[bottom, left, width, height](底坐标<x轴>、左坐标<y轴>、宽度、高度),数值的取值范围是左下角(原点)为 0,右上角为 1。以上参数为了生成“画中画”,几个参数都是百分比。
ax_sub_3 = plt.axes()
ax_sub_4 = plt.axes([0.4,0.6,0.5,0.3])
plt.show()
面向对象画图接口中类似的命令是:fig.add_axes(),例如下边的代码可以生成上下两个图:
fig = plt.figure()
ax1 = fig.add_axes([0.1, 0.5, 0.8, 0.4],xticklabels=[],
ylim=(-1.2, 1.2)) # xticklabels=[],表示上边的坐标轴无刻度。
ax2 = fig.add_axes([0.1, 0.1, 0.8, 0.4],ylim=(-1.2, 1.2))
x = np.linspace(0, 10)
ax1.plot(np.sin(x))
ax2.plot(np.cos(x))
10.2 plt.subplot()创建子图
plt.subplot() 在一个网格中创建一个子图。
plt.subplots_adjust 命令可以调整子图之间的间隔。
使用MATLAB接口创建子图使用:plt.subplot(2,3,i)
使用面向对象接口创建子图使用:fig.add_subplot(2,3,i)
上边例子中,subplot,前2个参数是子图矩阵的行和列(2行3列子图,共6个),最后一个i是子图的编号(从1到6)。下边是2个例子:
for i in range(1,7):
plt.subplot(2,3,i)
plt.text(0.5,0.5,str((2,3,i)),fontsize=18,ha='center')
plt.subplots_adjust(hspace=0.4,wspace=0.4) #间隙的高和宽分别是子图的40%
fig = plt.figure()
fig.subplots_adjust(hspace=0.4,wspace=0.4)
for i in range(1,7):
axi=fig.add_subplot(2,3,i)
axi.text(0.5,0.5,str((2,3,i)),fontsize=16,ha='center')
10.3 plt.subplots()
本节介绍的方法跟上边相比多个s,如果需要隐藏子图的x轴和y轴,可以使用 plt.subplots()。该方法返回子图的Numpy数组。
plt.subplots()的参数包括:子图矩阵的行数和列数,可选参数sharex和sharey,参考下边例子:
fig,ax = plt.subplots(3,2,sharex='col' , sharey='row') #列共享x轴(sharex),行共享y轴
for i in range(0,3):
for j in range(0,2):
ax[i][j].text(0.5,0.5,str((i,j)),fontsize=16,ha='center')
fig.subplots_adjust(hspace=0.4,wspace=0.4)
10.4 plt.GridSpec()实现更复杂的排列方式
上边10.2和10.3节画的都是规则的多行多列子图网格,如果想画不规则的,可以使用plt.GridSpec()。plt.GridSpec()并不是规则图形,而是plt.subplot()的简易接口。
例如,下边的例子,生成了不规则的图形:
grid = plt.GridSpec(3,3)
plt.subplot(grid[0,:2])
plt.subplot(grid[0,2])
plt.subplot(grid[1:,:2])
plt.subplot(grid[1:,2])
下面是plt.GridSpec()的一个应用,多轴频次直方图:
11.文字和注释
在图表中增加文字标签,主要通过ax.text方法。
ax.text 方法需要一个 x 轴坐标、一个 y 轴坐标、一个字符串和一些可选参数,比如文字的颜色、字号、风格、对齐方式以及其他文字属性。
下边是一个例子:
births=pd.read_csv('python_science_handbook_data/births.csv')
quartiles = np.percentile(births['births'],[25,50,75]) #25,50,75分位数
mid,sig = quartiles[1],0.74*(quartiles[2]-quartiles[0])
# (mid - 5*sig,mid + 5*sig)
births = births.query('(births > @mid - 5*@sig) & (births < @mid + 5*@sig)')
births['day'] = births['day'].astype(int) #day列转化为int
# births['unix_time'] = 10000*births.year + 100*births.month + births.day ##19690101
births.index = pd.to_datetime(10000*births.year + 100*births.month + births.day,format='%Y%m%d')
births_by_date = births.pivot_table('births',[births.index.month, births.index.day])
births_by_date.index = [pd.datetime(2012, month, day)for (month, day) in births_by_date.index]
fig, ax = plt.subplots(figsize=(12, 4))
# births_by_date.loc['2012-1-1':'2012-1-1']
births_by_date.plot(ax=ax);
# 在图上增加文字标签
style = dict(size=10, color='gray')
ax.text('2012-1-1', 3950, "New Year's Day", **style)
ax.text('2012-7-4', 4250, "Independence Day", ha='center', **style)
ax.text('2012-9-4', 4850, "Labor Day", ha='center', **style)
ax.text('2012-10-31', 4600, "Halloween", ha='right', **style)
ax.text('2012-11-25', 4450, "Thanksgiving", ha='center', **style)
ax.text('2012-12-25', 3850, "Christmas ", ha='right', **style)
ax.set(title='USA births by day of year (1969-1988)',ylabel='average daily births')
ax.xaxis.set_major_locator(mpl.dates.MonthLocator())
ax.xaxis.set_minor_locator(mpl.dates.MonthLocator(bymonthday=15))
ax.xaxis.set_major_formatter(plt.NullFormatter())
ax.xaxis.set_minor_formatter(mpl.dates.DateFormatter('%h'));