一、训练模型

构建模型(假设为model)后,接下来就是训练模型。PyTorch训练模型主要包括加载数据集、损失计算、定义优化算法、反向传播、参数更新等主要步骤。

1.加载预处理数据集

加载和预处理数据集,可以使用PyTorch的数据处理工具,如torch.utils和torchvision等,这些工具将在第4章详细介绍。

2.定义损失函数

定义损失函数可以通过自定义方法或使用PyTorch内置的损失函数,如回归使用的losss_fun=mn.MSELoss(),分类使用的nn.BCELoss等损失函数,更多内容可参考本书5.2.4节。

3.定义优化方法

Pytoch常用的优化方法都封装在torch.optim里面,其设计很灵活,可以扩展为自定义的优化方法。所有的优化方法都是继承了基类optim.Optimizer,并实现了自己的优化步骤。

最常用的优化算法就是梯度下降法及其各种变种,具体将在5.4节详细介绍,这些优化算法大多使用梯度更新参数。

如使用SGD优化器时,可设置为optimizer=torch.optim.SGD(params,lr=0.001)。

4.循环训练模型

1)设置为训练模式:

model. train()

调用model.train()会把所有的module设置为训练模式。

2)梯度清零:

optimizer. zero_grad()

在默认情况下梯度是累加的,需要手工把梯度初始化或清零,调用optimizer.zero_grad()即可。

3)求损失值:

y_prev=model(x)

loss=loss_fun(y_prev,y_true)

4)自动求导,实现梯度的反向传播:

loss. backward()

5)更新参数:

optimizer.step()

5.循环测试或验证模型

1)设置为测试或验证模式:

model.eval()

调用model.eval()会把所有的training属性设置为False。

2)在不跟踪梯度模式下计算损失值、预测值等:

with. torch. no_grad():

6.可视化结果

下面我们通过实例来说明如何使用mm来构建网络模型、训练模型。

【说明】model.train()与model.eval()的使用

如果模型中有BN(Batch Normalization)层和Dropout,需要在训练时添加model.train(),在测试时添加model.eval()。其中model.train()是保证BN层用每一批数据的均值和方差,而model.eval()是保证BN用全部训练数据的均值和方差;而对于Dropout, model.train()是随机取一部分网络连接来训练更新参数,而model.eval()是利用到了所有网络连接。

二、实现神经网络实例

前面我们介绍了使用PyTorch构建神经网络的一些组件、常用方法和主要步骤等,本节通过一个构建神经网络的实例把这些内容有机结合起来。

1、背景说明

本节将利用神经网络完成对手写数字进行识别的实例,来说明如何借助nn工具箱来实现一个神经网络,并对神经网络有个直观了解。在这个基础上,后续我们将对nn的各模块进行详细介绍。实例环境使用PyTorch1.5+,GPU或CPU,源数据集为MNIST。

主要步骤如下。

利用PyTorch内置函数mnist下载数据。

利用torchvision对数据进行预处理,调用torch.utils建立一个数据迭代器。

可视化源数据。

利用nn工具箱构建神经网络模型。

实例化模型,并定义损失函数及优化器。

训练模型。

可视化结果。

神经网络的结构如图3-5所示。

使用两个隐含层,每层使用ReLU激活函数,输出层使用softmax激活函数,最后使用torch.max(out,1)找出张量out最大值对应索引作为预测值。

2、准备数据

import numpy as np

import torch

# 导入 pytorch 内置的 mnist 数据

from torchvision.datasets import mnist

#import torchvision

import torchvision.transforms as transforms

from torch.utils.data import Dataloader

#导入nn及优化器

import torch.nn.functional as F

import torch.optim as optim

from torch import nn

 

from torch.utils.tensorboard import SummaryWriter

 

# 定义一些超参数

train_batch_size = 64

test_batch_size = 128

learning_rate = 0.01

num_epoches = 20

 

transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize([0.5],[0.5])])

 

#定义预处理函数

transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize([0.5],[0.5])])

#下载数据,并对数据进行预处理

train_dataset = mnist.MNIST('../data/', train-True, transform=transform, download-True)

test_dataset = mnist.MNIST('../data/', train=False, transform=transform)

#得到一个生成器

train_loader = DataLoader(train_dataset, batch_size=train_batch_size, shuffle=True)

test_loader = DataLoader(test_dataset, batch_size=test_batch_size, shuffle=False)

examples = enumerate(test_loader)

batch_idx, (example_data, example_targets) = next(examples)

 

example_data.shape

3、可视化源数据

import matplotlib.pyplot as plt

%matplotlib inline

 

examples = enumerate(test_loader)

batch_idx,(example_data, example_targets) = next(examples)

 

fig = plt.figure()

for i in range(6):

 plt.subplot(2,3,i+1)

 plt.tight_layout()

 plt.imshow(example_data[i][0], cmap='gray', interpolation='none')

 plt.title("Ground Truth: {}".format(example_targets[i]))

 plt.xticks([])

 plt.yticks([])

4、构建模型

class Net(nn.Module):

    """

    使用sequential构建网络,Sequential()函微的功能是将网络的层组合到一起

    """

def __init__(self, in_dim, n_hidden_1, n_hidden_2, out_dim):

    super(Net, self).__init__()

    self.flatten = nn.Flatten()

    self.layer1 = nn.Sequential(nn.Linear(in_dim, n_hidden_1),nn.BatchNormld(n_hidden_1))

    self.layer2 = nn.Sequential(nn.Linear(n_hidden_1, n_hidden_2),nn.BatchNormld(n_hidden_2))

    self.out = nn.Sequential(nn.Linear(n_hidden_2, out_dim))

 

    

def forward(self, x):

    x=self.flatten(x)

    x = F.relu(self.layer1(x))

    x = F.relu(self.layer2(x))

    x = F.softmax(self.out(x),dim=1)

return x

 

lr =0.01

momentum = 0.9

 

#实例化模型

device = torch.device("cuda:0" if torch.cuda.is_available() else"cpu")

model = Net(28 * 28, 300, 100, 10)

model.to(device)

 

#定义损失函数和优化器

criterion = nn.CrossEntropyLoss()

optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum)

5、训练模型

#开始训练

losses = []

acces = []

eval_losses = []

eval_acces = []

writer = SummaryWriter(log_dir='logs',comment='train-loss')

 

for epoch in range(num_epoches):

    train_loss = 0

    train_acc = 0

    model.train()

    #动态修改参数学习率

    if epoch%5==0:

        optimizer.param_groups[0]['lr']*=0.9

        print("学习率:{:.6f}".format(optimizer.param_groups[0]['lr']))

    for img, label in train_loader:

        img=img.to(device)

        label = label.to(device)

        #正向传播

        out = model(img)

        loss = criterion(out, label)

        #反向传播

        optimizer.zero_grad()

        loss.backward()

        optimizer.step()

        #记录误差

        train_loss += loss.item()

        #保存loss的数据与epoch数值

        writer.add_scalar('Train', train_loss/len(train_loader), epoch)

        #计算分类的准确率

        _, pred = out.max(1)

        num_correct = (pred == label).sum().item()

        acc = num_correct / img.shape[0]

        train_acc += acc

 

losses.append(train_loss / len(train_loader))

acces.append(train_acc / len(train_loader))

# 在测试集上检验救果

eval_loss = 0

eval_acc = 0

#net.eval()#将模型改为预测模式

model.eval()

for img, label in test_loader:

    img=img.to(device)

    label = label.to(device)

    img = img.view(img.size(0),-1)

    out = model(img)

    loss = criterion(out, label)

    # 记录误差

    eval_loss += loss.item()

    # 记录准确率

    _, pred = out.max(1)

    num_correct = (pred = label).sum().item()

    acc = num_correct / img.shape[0]

    eval_acc += acc

 

eval_losses.append(eval_loss / len(test_loader))

eval_acces.append(eval_acc / len(test_loader))

print('epoch: {}, Train Loss:{:.4f}, Train Acc: {:.4f}, Test Loss: {:.4f}, Test Acc: {:.4f})

    .format(epoch, train_loss / len(train_loader), train_acc / len(train_loader),

                eval_loss / len(test_loader), eval_acc / len(test_loader)))

 

plt.title('train loss')

plt.plot(np.arange(len(losses)), losses)

plt.legend(['Train Loss'], loc='upper right')

Logo

汇聚全球AI编程工具,助力开发者即刻编程。

更多推荐