1、输入预处理

2、特征提取(CNN)

3、分类器分类

4、损失函数

网络结构:

                                                              图来源

Lenet:输入3*32*32(RGB)
conv1:    In-32*32 k:5  pd:0  s:1 -----out:(32+2*0-5)/1+1=28   6*28*28
maxpool1: In-28*28 k:2 s:2       -----out:(28-2)/2+1=14   6*14*14
conv2:    In-14*14 k:5    s:1 -----out:(14-5)/1+1=10   16*10*10
maxpool2: In-10*10 k:2 s:2       -----out:(10-2)/2+1=5   16*5*5
fc1:      In:16*5*5=400          -----out:120
fc2:      In:120                 -----out:84
fc3:      In:84                  -----out:2(猫/狗二分类;原始LeNet-5为10)

代码部分:

class Lenet(nn.Module):
    """LeNet-style CNN for small images (cat/dog classification by default).

    With a ``in_channels x 32 x 32`` input the feature maps are:
        conv1 (5x5, no padding) -> 6x28x28 -> max-pool -> 6x14x14
        conv2 (5x5, no padding) -> 16x10x10 -> max-pool -> 16x5x5
        fc1 400 -> 120, fc2 120 -> 84, fc3 84 -> num_classes

    Args:
        num_classes: number of output logits (default 2).
        in_channels: number of input image channels (default 3, RGB).

    Sub-module names (conv1, conv2, fc1, fc2, fc3) are kept stable so
    previously saved state_dicts still load.
    """

    def __init__(self, num_classes: int = 2, in_channels: int = 3):
        super().__init__()
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels, 6, 5),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(6, 16, 5),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
        )
        # 16 channels x 5 x 5 spatial: only matches a 32x32 input image.
        self.fc1 = nn.Sequential(
            nn.Linear(16 * 5 * 5, 120),
            nn.ReLU(),
        )
        self.fc2 = nn.Sequential(
            nn.Linear(120, 84),
            nn.ReLU(),
        )
        # Raw logits, no softmax: nn.CrossEntropyLoss applies log-softmax itself.
        self.fc3 = nn.Linear(84, num_classes)

    def forward(self, x):
        """Map a (batch, in_channels, 32, 32) tensor to (batch, num_classes) logits."""
        x = self.conv1(x)
        x = self.conv2(x)
        # flatten(1) keeps the batch dim; unlike view it also works on
        # non-contiguous tensors and empty batches.
        x = x.flatten(1)
        x = self.fc1(x)
        x = self.fc2(x)
        return self.fc3(x)

数据预处理:

# Data preprocessing.
# NOTE(review): `tranforms` is presumably the file's alias for torchvision.transforms.
# Training uses RandomResizedCrop augmentation. Evaluation must be deterministic,
# so the test set is only resized to the network's 32x32 input.
transform = tranforms.Compose([
    tranforms.RandomResizedCrop(32),
    tranforms.ToTensor(),
    tranforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
])
test_transform = tranforms.Compose([
    tranforms.Resize((32, 32)),
    tranforms.ToTensor(),
    tranforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
])

root = './data/dog_cat'
# ImageFolder numbers classes by sorted sub-folder name; with folders named
# cat/ and dog/ (presumably) that gives cat=0, dog=1.
dataset_train = datasets.ImageFolder(root + '/train', transform)
dataset_test = datasets.ImageFolder(root + '/test', test_transform)

trainloder = DataLoader(dataset_train, batch_size=batchs, shuffle=True)
# Order does not matter for evaluation, so the test set is not shuffled.
testloder = DataLoader(dataset_test, batch_size=batchs, shuffle=False)

训练:

# Build the network on the target device, then the loss and optimizer for training.
model = Lenet()
model = model.to(device)
# Cross-entropy over the raw logits produced by the model's last layer.
loss_fn = nn.CrossEntropyLoss()
# SGD with momentum 0.9; the learning rate `lr` is defined elsewhere in the script.
optimizer = optim.SGD(params=model.parameters(), lr=lr, momentum=0.9)



# Training
if __name__=="__main__":
    for epoch in range(epochs):
        # Back to training mode (the evaluation pass below switches to eval mode).
        model.train()
        sumloss = 0.0
        for i, (x, y) in enumerate(trainloder):
            x, y = x.to(device), y.to(device)
            ypred = model(x)

            loss = loss_fn(ypred, y)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # .item() yields a Python float, so no graph/device tensor is retained.
            sumloss += loss.item()

            # Report the mean loss over the last 100 mini-batches.
            if i % 100 == 99:
                print("[%d,%d] loss:%.03f" % (epoch + 1, i + 1, sumloss / 100))
                sumloss = 0.0

        # Evaluate on the test set after every epoch.
        model.eval()
        acc = 0
        total = 0
        with torch.no_grad():
            for x, y in testloder:
                x, y = x.to(device), y.to(device)
                ypred = model(x)

                _, pred = torch.max(ypred, 1)
                total += y.size(0)
                # Accumulate a plain int rather than a device tensor.
                acc += (pred == y).sum().item()
        accuracy = 100 * acc / total if total else 0.0
        print('第%d个epoch的准确率为:%d%%' % (epoch + 1, accuracy))
        # NOTE(review): `opt` is not defined in the visible code; presumably an
        # argparse namespace defined elsewhere (the test section loads from "model/").
        torch.save(model.state_dict(), '%s/lenet_model_%03d.pth' % (opt.outf, epoch + 1))

测试:

# Class names in label order. ImageFolder assigns labels by sorted folder name,
# so read them from the dataset; the hard-coded ['dog', 'cat'] had them swapped.
classes = dataset_test.classes
if __name__=="__main__":
    # map_location lets a checkpoint saved on a GPU be loaded on any device.
    model.load_state_dict(torch.load("model/lenet_model_100.pth", map_location=device))
    model.eval()

    num_classes = len(classes)
    acc = 0
    total = 0
    # Per-class number of correct predictions and number of samples, indexed by label.
    class_acc = [0] * num_classes
    total_lb = [0] * num_classes
    with torch.no_grad():
        for x, y in testloder:
            x, y = x.to(device), y.to(device)
            ypred = model(x)
            _, pred = torch.max(ypred, 1)

            # Boolean tensor: True where the prediction matches the label.
            # (No squeeze(): that would turn a size-1 batch into a 0-dim tensor.)
            correct = pred == y
            total += y.size(0)
            acc += correct.sum().item()

            # Walk the actual batch; the last batch may be smaller than `batchs`.
            for lb, ok in zip(y.tolist(), correct.tolist()):
                class_acc[lb] += int(ok)
                total_lb[lb] += 1

    overall = 100 * acc / total if total else 0.0
    print("Test Average accuracy is:{:.4f}%".format(overall))

    for i in range(num_classes):
        # Guard against a class with no test samples.
        per_class = 100 * class_acc[i] / total_lb[i] if total_lb[i] else 0.0
        print('Accuracy of %5s : %4f %%' % (classes[i], per_class))

    print(classes, total_lb)

 

 

 

 

 

 

 

 

 

 

 

 

 

 

Logo

汇聚全球AI编程工具,助力开发者即刻编程。

更多推荐