Python

逻辑回归从零实现以及PyTorch实现

2021-09-11  本文已影响0人  酷酷的群

逻辑回归原理参考链接:线性分类|机器学习推导系列(四)

一、逻辑回归从零开始实现

1. 导入所需要的库

import numpy as np
import matplotlib.pyplot as plt

2. 人工构造数据集

构造一个比较简单的二分类数据集,满足高斯分布,并进行了数据的可视化。

def create_data(size):
    X0=np.random.normal(2,1,(size,2))
    y0=np.zeros(size)
    X1=np.random.normal(-2,1,(size,2))
    y1=np.ones(size)
    X=np.concatenate((X0,X1),axis=0)
    y=np.concatenate((y0,y1),axis=0)
    return X,y

X,y=create_data(1000)
X_test,y_test=create_data(100)


# 可视化数据
plt.scatter(X[:,0],X[:,1],c=y,s=40,lw=0, cmap='RdYlGn')
plt.show()

# 添加偏置
X=np.insert(X,0,1,axis=1)
y=y.reshape(y.shape[0],-1)
X_test=np.insert(X_test,0,1,axis=1)
y_test=y_test.reshape(y_test.shape[0],-1)

3. 初始化模型参数

n_samples,n_features=np.shape(X)
limit=np.sqrt(n_features)
w=np.random.uniform(-limit,limit,(n_features,1)) # 学习参数
lr=.1 # 学习率
iters=1001 # 迭代次数

4. 定义sigmoid函数

def sigmoid(x):
    return 1/(1+np.exp(-x))

5. 定义优化算法

def gradient_descent(X,y,w,lr):
    y_pred=sigmoid(X@w)
    gradient=-((y-y_pred)*X).sum(axis=0)
    gradient=gradient.reshape(gradient.shape[0],-1)
    w=w-lr*gradient
    return w

6. 定义损失函数

def loss(X,y,w):
    y_pred=sigmoid(X@w)
    delta=1e-7
    left_log=np.log(y_pred+delta)
    right_log=np.log(1-y_pred+delta)
    l=-(y*left_log+(1-y)*right_log).sum()
    return l

7. 训练模型

定义测试函数。

def test(X_test,y_test):
    y_test_pred=sigmoid(X_test@w)
    y_test_pred[y_test_pred>=0.5]=1.0
    y_test_pred[y_test_pred<0.5]=0.0
    correct_list=[]
    for i,pred in enumerate(y_test_pred):
        is_correct = 1 if pred==y_test[i] else 0
        correct_list.append(is_correct)
    acc = sum(correct_list) / len(y_test)
    return acc

定义训练过程,并且进行记录。

iter_num=[]
acc_num=[]
loss_num=[]
for i in range(iters):
    w=gradient_descent(X,y,w,lr)
    if i%50 == 0:
        iter_num.append(i)
        acc_num.append(test(X_test,y_test))
        loss_num.append(loss(X,y,w))

8. 结果可视化

plt.subplot(221)
plt.plot(iter_num,acc_num,color='r',label='acc') 
plt.xlabel('epochs')
plt.ylabel('acc')
plt.title("acc")
plt.legend()
plt.show()

plt.subplot(222)
plt.plot(iter_num,loss_num,color='r',label='loss') 
plt.xlabel('epochs')
plt.ylabel('loss')
plt.title("loss")
plt.legend()
plt.show() 

二、使用torch.nn实现逻辑回归

1. 导入所需要的库

import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt

2. 人工构造数据集

使用与之前同样的数据。

def create_data(size):
    n_data=torch.ones(size,2)
    X0=torch.normal(2*n_data,1)
    y0=torch.zeros(size)
    X1=torch.normal(-2*n_data,1)
    y1=torch.ones(size)
    
    X=torch.cat((X0,X1),0).type(torch.FloatTensor)
    y=torch.cat((y0,y1),0).type(torch.FloatTensor)
    return X,y
    
X,y=create_data(1000)
X_test,y_test=create_data(100)

plt.scatter(X.data.numpy()[:,0],X.data.numpy()[:,1],c=y.data.numpy(),s=40,lw=0,cmap='RdYlGn')
plt.show()

3. 定义模型

逻辑回归相当于一个两层神经网络,输出层只有一个节点,并且激活函数为sigmoid函数。

class LogisticRegression(nn.Module):
    def __init__(self):
        super(LogisticRegression,self).__init__()
        self.lr = nn.Linear(2,1)
        self.sigmoid = nn.Sigmoid()
        
    def forward(self,X):
        X=self.lr(X)
        X=self.sigmoid(X)
        return X
    
model = LogisticRegression()

4. 定义损失函数和优化器

criterion=nn.BCELoss()
optimizer=torch.optim.Adam(model.parameters(),lr=1e-3)

5. 训练模型

iter_num=[]
acc_num=[]
loss_num=[]
for epoch in range(3):
    logit=model(X)
    loss=criterion(logit,y)
    
    y_pred=logit.ge(0.5).float()
    correct=(y_pred==y).sum()
    acc=correct.item()/X.size(0)
    
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    
    if (epoch+1)%1==0:
        iter_num.append(epoch)
        acc_num.append(acc)
        loss_num.append(loss.data.item())
        print('epoch:{}'.format(epoch+1),',','loss:{:.4f}'.format(loss.data.item()),',','acc:{:.4f}'.format(acc))

6. 结果可视化

plt.subplot(221)
plt.plot(iter_num,acc_num,color='r',label='acc') 
plt.xlabel('epochs')
plt.ylabel('acc')
plt.title("acc")
plt.legend()
plt.show()

plt.subplot(222)
plt.plot(iter_num,loss_num,color='r',label='loss') 
plt.xlabel('epochs')
plt.ylabel('loss')
plt.title("loss")
plt.legend()
plt.show() 
上一篇 下一篇

猜你喜欢

热点阅读