Data Visualization: pandas Pivot Tables and seaborn Heatmaps
2020-02-22
洗洗睡吧i
1. Create the data to display
import itertools
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
# === define parameters ==================
para_names = ['layer_n', 'activition', 'seed']
layer_n = [1, 2, 3, 4, 5, 6]
activition = ['tanh', 'sigmod', 'relu']
seed = [11, 17, 19]
# Build the DataFrame: one row per parameter combination (6 * 3 * 3 = 54 rows).
# DataFrame.append was removed in pandas 2.0, so the frame is built in a single call.
df = pd.DataFrame(list(itertools.product(layer_n, activition, seed)), columns=para_names)
# Fake some training results so there is something to plot
# activ_2_num = pd.factorize(df['activition'])[0].astype('int')  # activation names are strings; this maps them to integer codes
activ_dict = {'tanh': 2, 'sigmod': 4, 'relu': 6}  # alternatively, define a dict and map it directly
df['results'] = df['layer_n'] + df['activition'].map(activ_dict) + df['seed'] * 0.1 + np.random.random(len(df))
df['results'] = df['results'].astype('float')  # make sure the column is float
print(df.head())
Output:
   layer_n activition  seed   results
0        1       tanh    11  4.261361
1        1       tanh    17  4.822595
2        1       tanh    19  4.929088
3        1     sigmod    11  6.698047
4        1     sigmod    17  7.020531
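The commented-out line in the code above mentions `pd.factorize` as an alternative to a hand-written dictionary. As a minimal sketch (the four-element series `activ` is a made-up example, not the article's data), the two approaches compare roughly like this: `factorize` assigns integer codes in order of first appearance, while a dictionary lets you choose the numbers yourself.

```python
import pandas as pd

# Hypothetical toy column of activation names (same spelling as the article's data)
activ = pd.Series(['tanh', 'sigmod', 'relu', 'tanh'], name='activition')

# Option 1: factorize assigns codes 0, 1, 2, ... in order of first appearance
codes, uniques = pd.factorize(activ)
print(codes)
print(list(uniques))

# Option 2: an explicit dictionary gives full control over the numeric values
activ_dict = {'tanh': 2, 'sigmod': 4, 'relu': 6}
mapped = activ.map(activ_dict)
print(mapped.tolist())
```

The dictionary is probably the better fit when the numbers carry meaning, as with the fake results here; `factorize` is convenient when any consistent integer label will do.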
2. Plot the training results as a line chart with error bands
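# Draw a line chart with error bands: the x-axis is the number of layers, the y-axis is the training result.
# Each activation function gets its own color and line style. The error band comes from
# the column that is not mapped to any plot dimension: the random seed.
plt.figure(figsize=(8, 6))
sns.lineplot(x='layer_n', y='results', hue='activition', style='activition',
             markers=True, data=df)
plt.grid(linestyle=':')
plt.show()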
![](https://img.haomeiwen.com/i206717/a4e519b5d15e2d31.png)
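To see where the line and its band come from: by default `sns.lineplot` aggregates rows that share the same x value and hue, drawing their mean as the line and a bootstrapped 95% confidence interval as the band. The sketch below reproduces only the central line with a plain `groupby`. It uses a small noise-free toy frame (an assumption for illustration, not the article's random data), and the `std` column is a simple measure of spread, not seaborn's confidence interval.

```python
import itertools
import pandas as pd

# Hypothetical noise-free toy data with the same column names as the article
para_names = ['layer_n', 'activition', 'seed']
toy = pd.DataFrame(list(itertools.product([1, 2], ['tanh', 'relu'], [11, 17, 19])),
                   columns=para_names)
activ_dict = {'tanh': 2, 'relu': 6}
toy['results'] = toy['layer_n'] + toy['activition'].map(activ_dict) + toy['seed'] * 0.1

# Collapse the seed dimension: one mean (the plotted line) and one std per (layer_n, activition)
summary = toy.groupby(['layer_n', 'activition'])['results'].agg(['mean', 'std'])
print(summary)
```

Each row of `summary` corresponds to one marker on the chart; the three seeds behind it are what give the band its width.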
3. Display the results with a pandas pivot table and a seaborn heatmap
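# Create the pivot table.
# The column that is not specified (seed) is aggregated by taking the maximum.
dt = pd.pivot_table(df, index=['layer_n'], columns=['activition'], values=['results'], aggfunc=['max'])
print(dt)
print(dt.columns)
# Find the maximum value and its index (layer_n, activition).
# Select the ('max', 'results') block first so that stack() returns a Series.
stacked = dt['max']['results'].stack()
max_value, max_idx = stacked.max(), stacked.idxmax()
print(f' - the max value is {max_value};\n - the index is {max_idx}...')
# The pivot table has MultiIndex columns; keep only the activation-function level
dt.columns = dt.columns.get_level_values(2)
# dt.index = list(dt.index)
print(dt)
dt.sort_index(axis=0, ascending=False, inplace=True)  # reorder the rows if needed
dt.sort_index(axis=1, ascending=False, inplace=True)  # reorder the columns if needed
# Draw the heatmap: rows (y-axis) are the number of layers, columns (x-axis) are the activation functions.
# Cell color encodes the training result; the darker the cell, the better the result.
plt.figure(figsize=(8, 6))
g = sns.heatmap(dt, vmin=0.0, annot=True, fmt='.2g', cmap='Blues', cbar=True)
plt.show()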
![](https://img.haomeiwen.com/i206717/4da32bc84dabe053.png)
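The MultiIndex columns appear because `values` and `aggfunc` are passed as lists. As a sketch of an alternative (again on a small noise-free toy frame that is assumed for illustration), passing them as plain strings should give flat columns directly, so no relabeling step is needed before the heatmap. The names `toy`, `flat`, `best_value` and `best_idx` are made up for this example.

```python
import itertools
import pandas as pd

# Hypothetical noise-free toy data with the same column names as the article
para_names = ['layer_n', 'activition', 'seed']
toy = pd.DataFrame(list(itertools.product([1, 2], ['tanh', 'relu'], [11, 17, 19])),
                   columns=para_names)
activ_dict = {'tanh': 2, 'relu': 6}
toy['results'] = toy['layer_n'] + toy['activition'].map(activ_dict) + toy['seed'] * 0.1

# Scalar values/aggfunc arguments produce a single-level column index
flat = pd.pivot_table(toy, index='layer_n', columns='activition',
                      values='results', aggfunc='max')
print(flat)
print(flat.columns.tolist())

# unstack() turns the table into a Series indexed by (activition, layer_n)
as_series = flat.unstack()
best_value, best_idx = as_series.max(), as_series.idxmax()
print(best_value, best_idx)
```

The list form used in the article is still useful when several aggregations (for example `['max', 'mean']`) are wanted in one table.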
References:
- pandas.pivot_table https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.pivot_table.html
- seaborn.lineplot https://seaborn.pydata.org/generated/seaborn.lineplot.html
- seaborn.heatmap https://seaborn.pydata.org/generated/seaborn.heatmap.html