10钢材缺陷检测分类

2023-02-24  本文已影响0人  Jachin111

数据预处理

# 导入数据
import pandas as pd
import numpy as np

import plotly_express as px
import plotly.graph_objects as go

from plotly.subplots import make_subplots

import matplotlib.pyplot as plt
import seaborn as sns
sns.set_theme(style="whitegrid")
%matplotlib inline

import warnings
warnings.filterwarnings("ignore")
df = pd.read_excel("faults.xlsx")
df.head()
image.png
# 数据分割
df1 = df.loc[:,"Pastry":]
df2 = df.loc[:,:"SigmoidOfAreas"]
df1.head()
image.png
# 下面是27个特征的数据
df2.head()
image.png
# 分类标签生成
columns = df1.columns.tolist()
columns
image.png
for i in range(len(df1)):
    for col in columns:
        if df1.loc[i, col] == 1:
            df1.loc[i, "Label"] = col

df1.head()
image.png
# 类型编码
dic = {}
for i, v in enumerate(columns):
    dic[v] = i
    
dic
image.png
df1["Label"] = df1["Label"].map(dic)
df1.head()
image.png
# 数据合并
df2["Label"] = df1["Label"]
df2.head()
image.png

EDA

# 数据的基本统计信息
# 缺失值
df2.isnull().sum()
image.png
# 单个特征分布
parameters = df2.columns[:-1].tolist()

sns.boxplot(data=df2, y="Steel_Plate_Thickness")
plt.show()
image.png

从箱型图中能够观察到单个特征的取值分布情况。下面绘制全部参数的取值分布箱型图

fig = make_subplots(rows=7, cols=4)

for i, v in enumerate(parameters):
    r = i // 4 + 1
    c = (i + 1) % 4
    
    if c == 0:
        fig.add_trace(go.Box(y=df2[v].tolist(), name=v), row=r, col=4)
    else:
        fig.add_trace(go.Box(y=df2[v].tolist(), name=v), row=r, col=c)
        
fig.update_layout(width=1000, height=900)
fig.show()
image.png

1.特征之间的取值范围不同,从负数到10M
2.部分特征的取值中存在异常值
3.有些特征的取值只存在0和1

样本不均衡

# 每种类别数量
df2["Label"].value_counts()
image.png

可以看到第6类的样本有673条,但是第4类的样本只有55条。明显地不均衡

# SMOTE解决
X = df2.drop("Label", axis=1)
y = df2[["Label"]]
from imblearn.over_sampling import SMOTE

smo = SMOTE(random_state=42)
X_smo, y_smo = smo.fit_resample(X, y)
y_smo
image.png
# 实施上采样后的结果
y_smo["Label"].value_counts()
image.png

统计一下每个分类变量的数量:现在我们发现每个类别下的样本是一样的,克服了样本不均衡问题

建模

# 随机打乱数据
from sklearn.utils import shuffle

df3 = pd.concat([X_smo, y_smo], axis=1)
df3 = shuffle(df3)
# 数据集划分
from sklearn import preprocessing

X = df3.drop("Label", axis=1)
X = preprocessing.scale(X)
y = df3[["Label"]]
from sklearn.model_selection import train_test_split

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=4)
# 建模评价
from sklearn.model_selection import cross_val_score
from sklearn import metrics

def build_model(model, X_test, y_test):
    model.fit(X_train, y_train)
    # 预测概率
    y_proba = model_LR.predict_proba(X_test)
    # 找出概率值最大的所在索引,作为预测的分类结果
    y_pred = np.argmax(y_proba, axis=1)
    y_test = np.array(y_test).reshape(943)
    
    print(f"{model}模型得分:")
    print("召回率: ", metrics.recall_score(y_test, y_pred, average="macro"))
    print("精准率: ", metrics.precision_score(y_test, y_pred, average="macro"))
# 逻辑回归
from sklearn.linear_model import LogisticRegression

model_LR = LogisticRegression()
build_model(model_LR, X_test, y_test)
image.png

逻辑回归

# 建模
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import cross_val_score
from sklearn import metrics

model_LR = LogisticRegression()
model_LR.fit(X_train, y_train)
image.png
# 预测概率
y_proba = model_LR.predict_proba(X_test)
y_proba[:3]
image.png
# 找出概率值最大的所在索引,作为预测的分类结果
y_pred = np.argmax(y_proba, axis=1)
y_pred[:3]
image.png
# 评价
# 混淆矩阵
confusion_matrix = metrics.confusion_matrix(y_test, y_pred)
confusion_matrix
image.png
y_pred.shape
image.png
y_test = np.array(y_test).reshape(943)

print("召回率: ", metrics.recall_score(y_test, y_pred, average="macro"))
print("精准率: ", metrics.precision_score(y_test, y_pred, average="macro"))
image.png

随机森林回归

from sklearn.ensemble import RandomForestClassifier

model_RR = RandomForestClassifier()
model_RR.fit(X_train, y_train)
image.png
# 预测概率
y_proba = model_RR.predict_proba(X_test)
# 最大概率的索引
y_pred = np.argmax(y_proba, axis=1)

print("召回率: ", metrics.recall_score(y_test, y_pred, average="macro"))
print("精准率: ", metrics.precision_score(y_test, y_pred, average="macro"))
image.png

SVR

from sklearn.svm import SVC

svm_model = SVC(probability=True)
svm_model.fit(X_train, y_train)
image.png
# 预测概率
y_proba = svm_model.predict_proba(X_test)
# 最大概率的索引
y_pred = np.argmax(y_proba, axis=1)

print("召回率: ", metrics.recall_score(y_test, y_pred, average="macro"))
print("精准率: ", metrics.precision_score(y_test, y_pred, average="macro"))
image.png

决策树回归

from sklearn.tree import DecisionTreeClassifier

model_tree = DecisionTreeClassifier()
model_tree.fit(X_train, y_train)

# 预测概率
y_proba = model_tree.predict_proba(X_test)
# 最大概率的索引
y_pred = np.argmax(y_proba, axis=1)

print("召回率: ", metrics.recall_score(y_test, y_pred, average="macro"))
print("精准率: ", metrics.precision_score(y_test, y_pred, average="macro"))
image.png

神经网络

from sklearn.neural_network import MLPClassifier

mlp = MLPClassifier()
mlp.fit(X_train, y_train)

# 预测概率
y_proba = mlp.predict_proba(X_test)
# 最大概率的索引
y_pred = np.argmax(y_proba, axis=1)

print("召回率: ", metrics.recall_score(y_test, y_pred, average="macro"))
print("精准率: ", metrics.precision_score(y_test, y_pred, average="macro"))
image.png

GBDT

from sklearn.ensemble import GradientBoostingClassifier

gbdt = GradientBoostingClassifier(
#        loss="deviance",
#        learning_rate=1,
#        n_estimators=5,
#        subsample=1,
#        min_samples_split=2,
#        min_samples_leaf=1,
#        max_depth=2,
#        init=None,
#        random_state=None,
#        max_feature=None,
#        verbose=0,
#        max_leaf_nodes=None,
#        warm_start=False
)
gbdt.fit(X_train, y_train)

# 预测概率
y_proba = gbdt.predict_proba(X_test)
# 最大概率的索引
y_pred = np.argmax(y_proba, axis=1)

print("召回率: ", metrics.recall_score(y_test, y_pred, average="macro"))
print("精准率: ", metrics.precision_score(y_test, y_pred, average="macro"))
image.png

LightGBM

from lightgbm import LGBMClassifier

lgb = LGBMClassifier()
lgb.fit(X_train, y_train)

# 预测概率
y_proba = lgb.predict_proba(X_test)
# 最大概率的索引
y_pred = np.argmax(y_proba, axis=1)

print("召回率: ", metrics.recall_score(y_test, y_pred, average="macro"))
print("精准率: ", metrics.precision_score(y_test, y_pred, average="macro"))
image.png
上一篇下一篇

猜你喜欢

热点阅读