数据可视化呆鸟的Python数据分析数据可视化分析

数据可视化:pandas透视图、seaborn热力图

2020-02-22  本文已影响0人  洗洗睡吧i

1. 创建需要展示的数据

import itertools

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

# === define paras ==================
para_names = ['layer_n', 'activition', 'seed']

layer_n = [1, 2, 3, 4, 5, 6]
activition = ['tanh', 'sigmod', 'relu']
seed = [11, 17, 19]

# 创建 dataframe
df =  pd.DataFrame([], columns=para_names)
for values in itertools.product(layer_n, activition, seed):
    newline = pd.DataFrame(list(values), index=para_names)
    df = df.append(newline.T, ignore_index=True)

# 伪造一些训练结果,方便展示
# activ_2_num = pd.factorize(df['activition'])[0].astype('int')  # 激活函数是字符类型,将其映射成整数形
activ_dict = {'tanh': 2, 'sigmod': 4, 'relu': 6}  # 也可以直接定义字典,然后replace
df['results'] = df['layer_n'] + df['activition'].replace(activ_dict) + df['seed'] * 0.1 + np.random.random((54,))
df['results'] = df['results'].astype('float')  # 转换成浮点类型
print(df.head())

输出:

  layer_n activition seed   results
0       1       tanh   11  4.261361
1       1       tanh   17  4.822595
2       1       tanh   19  4.929088
3       1     sigmod   11  6.698047
4       1     sigmod   17  7.020531

2. 绘制带误差的折线图展示训练结果

# 绘制带误差的折线图,横轴为网络层数,纵轴为训练结果,
# 激活函数采用不同颜色的线型,误差来自于没有指定的列:不同的随机种子seed
plt.figure(figsize=(8, 6))
sns.lineplot(x='layer_n', y='results', hue='activition',  style='activition', 
             markers=True, data=df)
plt.grid(linestyle=':')
plt.show()

3. 使用pandas透视图、seaborn热力图来展示

# 创建透视图,
# 对于没有指定的列(seed),按最大值进行统计
dt = pd.pivot_table(df, index=['layer_n'], columns=['activition'], values=['results'], aggfunc=[max])
print(dt)
print(dt.columns)  

# 找到最大值、最大值所对应的索引
max_value, max_idx = dt.stack().max(), dt.stack().idxmax()
print(f' - the max value is {max_value};\n - the index is {max_idx}...')

# 透视图变成了多重索引(MultiIndex),重新调整一下
new_col = dt.columns.levels[2]
dt.columns = new_col
# dt.index = list(dt.index)
print(dt)

dt.sort_index(axis=0, ascending=False, inplace=True)  # 必要时将索引重新排序
dt.sort_index(axis=1, ascending=False, inplace=True)  # 必要时将索引重新排序

# 绘制热力图,横轴为网络层数,纵轴为激活函数,
# 栅格的颜色代表训练结果,颜色越深结果越好
plt.figure(figsize=(8, 6))
g = sns.heatmap(dt, vmin=0.0, annot=True, fmt='.2g', cmap='Blues', cbar=True)
plt.show()

ref:

上一篇 下一篇

猜你喜欢

热点阅读