卷积网络:实现手写数字是识别50轮准确率97.3%

news/2024/7/21 7:48:29 标签: python, 深度学习, 图像处理

卷积网络:实现手写数字是识别50轮准确率

  • 1 导入必备库
  • 2 torchvision内置了常用数据集和最常见的模型
  • 3 数据批量加载
  • 4 绘制样例
  • 5 创建模型
  • 7 设置是否使用GPU
  • 8 设置损失函数和优化器
  • 9 定义训练函数
  • 10 定义测试函数
  • 11 开始训练
  • 12 绘制损失曲线并保存
  • 13 绘制准确率曲线并保存

1 导入必备库

python">import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
print(torch.__version__)

输出:

1.12.1+cu102

2 torchvision内置了常用数据集和最常见的模型

python">import torchvision
from torchvision.transforms import ToTensor
''' transforms.ToTensor    
    1.转化为一个 tensor
    2.转换到0-1之间
    3.会将channel放在第一维度上
'''
train_ds = torchvision.datasets.MNIST('data/',
                                      train=True,
                                      transform=ToTensor(),
                                      download=False
                                     )
test_ds = torchvision.datasets.MNIST('data/',
                                     train=False,
                                     transform=ToTensor(),
                                     download=False  
                                    )
print(len(train_ds),len(test_ds))

输出:

60000 10000

3 数据批量加载

python">train_dl = torch.utils.data.DataLoader(train_ds, batch_size=64, shuffle=True)
test_dl = torch.utils.data.DataLoader(test_ds, batch_size=256)

# iter方法创建生成器,next方法返回一个批次的图像,shape属性返回一批次张量形状
imgs, labels = next(iter(train_dl))
print(imgs.shape)
print(labels.shape)

输出:

torch.Size([64, 1, 28, 28])
torch.Size([64])

4 绘制样例

python">plt.figure(figsize=(10, 1))
for i, img in enumerate(imgs[:10]):
    npimg = img.numpy()
    npimg = np.squeeze(npimg)
    plt.subplot(1, 10, i+1)
    plt.imshow(npimg)
    plt.xticks([])
    plt.yticks([])
    plt.xlabel(labels[i].numpy())
    # plt.axis('off') #关闭显示坐标
    plt.savefig('pics/3.1.jpg', dpi=400)

1

5 创建模型

python">class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 6, 5)   
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.liner_1 = nn.Linear(16*4*4, 256)
        self.liner_2 = nn.Linear(256, 10)
    def forward(self, input):
        x = F.max_pool2d(F.relu(self.conv1(input)), 2)
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)
        x = x.view(-1, 16*4*4)
        x = F.relu(self.liner_1(x))
        x = self.liner_2(x)
        return x

7 设置是否使用GPU

python">device = "cuda" if torch.cuda.is_available() else "cpu"
print("Using {} device".format(device))

# 将模型移动到DEVICE
model = Model().to(device)
print(model)

输出:

Using cuda device
Model(
  (liner_1): Linear(in_features=784, out_features=120, bias=True)
  (liner_2): Linear(in_features=120, out_features=84, bias=True)
  (liner_3): Linear(in_features=84, out_features=10, bias=True)
)

8 设置损失函数和优化器

python">loss_fn = torch.nn.CrossEntropyLoss()  # 损失函数
optimizer = torch.optim.SGD(model.parameters(), lr=0.001)

9 定义训练函数

python">def train(dataloader, model, loss_fn, optimizer):
    size = len(dataloader.dataset)
    train_loss, correct = 0, 0
    for X, y in dataloader:
        X, y = X.to(device), y.to(device)

        # Compute prediction error
        pred = model(X)
        loss = loss_fn(pred, y)

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        with torch.no_grad():
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()
            train_loss += loss.item()
    train_loss /= size
    correct /= size
    return train_loss, correct

10 定义测试函数

python">def test(dataloader, model):
    size = len(dataloader.dataset)
    test_loss, correct = 0, 0
    with torch.no_grad():
        for X, y in dataloader:
            X, y = X.to(device), y.to(device)
            pred = model(X)
            test_loss += loss_fn(pred, y).item()
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()
    test_loss /= size
    correct /= size
    return test_loss, correct

11 开始训练

python">epochs = 50
train_loss = []
train_acc = []
test_loss = []
test_acc = []

for epoch in range(epochs):
    epoch_loss, epoch_acc = train(train_dl, model, loss_fn, optimizer)
    epoch_test_loss, epoch_test_acc = test(test_dl, model)
    train_loss.append(epoch_loss)
    train_acc.append(epoch_acc)
    test_loss.append(epoch_test_loss)
    test_acc.append(epoch_test_acc)
    
    template = ("epoch:{:2d}/{:2d}, train_loss: {:.5f}, train_acc: {:.1f}% ," 
                "test_loss: {:.5f}, test_acc: {:.1f}%")
    print(template.format(
          epoch+1,epochs, epoch_loss, epoch_acc*100, epoch_test_loss, epoch_test_acc*100))
    
print("Done!")

输出:

epoch: 1/50, train_loss: 0.03559, train_acc: 24.1% ,test_loss: 0.00899, test_acc: 39.7%
epoch: 2/50, train_loss: 0.03413, train_acc: 51.0% ,test_loss: 0.00827, test_acc: 59.9%
epoch: 3/50, train_loss: 0.02756, train_acc: 62.9% ,test_loss: 0.00527, test_acc: 71.5%
······
epoch:48/50, train_loss: 0.00158, train_acc: 97.0% ,test_loss: 0.00037, test_acc: 97.1%
epoch:49/50, train_loss: 0.00155, train_acc: 97.0% ,test_loss: 0.00035, test_acc: 97.3%
epoch:50/50, train_loss: 0.00153, train_acc: 97.0% ,test_loss: 0.00035, test_acc: 97.3%
Done!

12 绘制损失曲线并保存

python">plt.plot(range(1, epochs+1), train_loss, label='train_loss', lw=2)
plt.plot(range(1, epochs+1), test_loss, label='test_loss', lw=2, ls="--")
plt.xlabel('epoch')
plt.legend()
plt.savefig('pics/2-4-5.jpg', dpi=400)

输出:
在这里插入图片描述

13 绘制准确率曲线并保存

python">plt.plot(range(1, epochs+1), train_acc, label='train_acc', lw=2)
plt.plot(range(1, epochs+1), test_acc, label='test_acc', lw=2, ls="--")
plt.xlabel('epoch')
plt.legend()
plt.savefig('pics/2-4-6.jpg', dpi=400)

输出:
在这里插入图片描述


http://www.niftyadmin.cn/n/5023862.html

相关文章

嵌入式学习笔记(25)串口通信的基本原理

三根通信线:Tx Rx GND (1)任何通信都要有信息作为传输载体,或者有线的或则无线的。 (2)串口通信时有线通信,是通过串口线来通信的。 (3)串口通信最少需要2根&#xff…

关于单片机的分频定时器的记录

记录一内部时钟: 对于单片机的频率原来一直不太明白,现在在学习进行记录: 主频: 以一个72M的STM32单片机作为主频为例子,这个72M主频说得是一秒钟产生72000000(七千两百万)个脉冲或周期&…

Explain 性能分析

目录 1. 能干什么 2. 如何分析 3. 各字段解释 1. 能干什么 使用 explainsql 的方式,分析查询语句的性能瓶颈。 ① 表的读取顺序; ② 数据读取操作的操作类型; ③ 哪些索引可以使用; ④ 哪些索引被实际使用; ⑤ 表之…

版本控制工具Git集成IDEA的学习笔记(第三篇Git私服)

本文章仅讲解代码提交和分支合并。 目录 一、提交并推送到私服 二、遇到的问题 1、将代码克隆到本地,idea打开,不使用主分支(master)进行开发操作。 2、也不使用dev分支进行开发操作,而是新建一个功能分支进行模块…

成都都市圈公共图书馆《乡村振兴战略下传统村落文化旅游设计》许少辉八一新著

成都都市圈公共图书馆《乡村振兴战略下传统村落文化旅游设计》许少辉八一新著

离线语音识别PocketSphinx(一)

总述 对于设备的控制,最简单方便的交互当属语音控制了,目前市面上也有许多的离线语音控制模块,可以任意更换需要识别的语句,但是识别模型这块都是闭源的,能够配置改动的不多,PocketSphinx是一个开源的离线…

栈的应用-综合计数器的实现

目录 前言 一、思路分析 二、代码实现 总结 前言 在实现综合计数器之前,大家应该先了解一下什么是前中后缀表达式 前缀、中缀和后缀表达式是表示数学表达式的三种不同方式。 前缀表达式(也称为波兰式或前缀记法):操作符位于操作数之前。…

2023百度十大科技前沿发明发布,超70%为大模型重构与创新

2023年9月12日,以“专利协同前沿创新,共筑AI原生未来”为主题的“2023百度十大科技前沿发明”发布会在北京召开。十大前沿发明中,超过70%涉及大模型和重构创新,一批创新AI原生应用落地,大量高价值专利成果披露。 百度首…