鞋子,靴子,拖鞋傻傻分不清楚 pytorch实现分类 入门小案例

news/2024/7/21 7:33:08 标签: 计算机视觉, 深度学习, 图像处理, 分类

鞋子,靴子,拖鞋傻傻分不清楚 pytorch入门

  • 前言
  • 方法
    • 网络
    • 优化器
    • 损失函数
    • 总体方法
  • 代码实现
    • 图片加字
    • 神经网络
  • 总结

前言

从入学到现在已经两个多月了,看了一个多月的论文不知道学到了啥
在这里插入图片描述
正好最近看了看pytorch的入门,像休息休息,就想着写个分类玩玩吧,但不知道写啥,突然见看到一个数据集网站,有一个鞋子的数据集
在这里插入图片描述
这对我这种非常like鞋的人来说很有吸引力,那来整个鞋子分类吧。

在这里插入图片描述

方法

网络

这里我们选用的网络是DenseNet,相比于普通的CNN来说,Densenet可以使用各层提取的特征,从而避免特征的丢失,同样,我们再进行分类也希望尽可能多的特征为我们所用,Densenet的网络结构如下
在这里插入图片描述

优化器

使用的是随机梯度下降优化器,其中学习率设置为0.001,动量为0.5
具体的内容可以参考该博客详解随机梯度下降法(Stochastic Gradient Descent,SGD)

损失函数

使用的是交叉熵损失函数

这里多说几句,最开始我看到这个损失的时候一脸懵逼,咋着,一个数还能和一个向量比较一下子
在这里插入图片描述
直到后面我看到b站的讲解,好吧,可能!

在这里插入图片描述
这就是b站里给到的公式

这里的x是指图像真实的类别,class是指图像在该类别的得分,x[j]是指所有类别在预测后的得分

以我们要介绍的网络为例,因为要判断的只有三类,鞋子,靴子和凉鞋,则我们最终的输出维度是3,即最终会输出一个向量,这个向量有三个值,分别代表分为鞋子,靴子和凉鞋的概率得分,越大就表示图片属于这一类的可能性越大。

例如我们输出的结果为[0.1 , 0.2, 0.8]

假如我们要输入的图片类别是第2类(从0开始算),那么带入上面公式就是

-2*0.8+log(e(0.1)+e(0.2)+e(0.3))

则当上面公式越小时,越接近真实结果

总体方法

这样就很简单了,就是利用我们现有的1.5万照片去训练该网络(这里做的比较糙,没有设置验证集和测试集),采用随机梯度下降的方式进行训练,每次训练的图片数量为10,所有图片被作为输入训练一次后为一个epoch,总共训练50个epoch,训练结束后就是我们需要的模型了。

代码实现

图片加字

在判断类别后在图片上加上类别,方便看

from PIL import ImageFont, ImageDraw, Image
import numpy as np
import cv2


def settags(info, img):
    url = img
    img_cv = cv2.imread(img)
    img = Image.fromarray(img_cv)
    font1 = ImageFont.truetype("./simsun.ttc", 100)
    draw = ImageDraw.Draw(img)
    draw.text((10, 10), info, font=font1, fill=(0, 0, 255))
    img1 = np.array(img)
    cv2.imwrite('r' + url, img1)

神经网络

卷积层


class ConvLayer(torch.nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride, is_last=False):
        super(ConvLayer, self).__init__()
        reflection_padding = int(np.floor(kernel_size / 2))
        self.reflection_pad = nn.ReflectionPad2d(reflection_padding)
        self.conv2d = nn.Conv2d(in_channels, out_channels, kernel_size, stride)
        self.dropout = nn.Dropout2d(p=0.5)
        self.is_last = is_last

    def forward(self, x):
        # 图片进行填充 保证输出都是相同大小,从而才能使后面的层也使用前面的特征
        out = self.reflection_pad(x)
        out = self.conv2d(out)
        if self.is_last is False:
            out = F.leaky_relu(out, inplace=True)
        return out

单个densenet

# Dense convolution unit
class DenseConv2d(torch.nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride, is_Last=False):
        super(DenseConv2d, self).__init__()
        self.dense_conv = ConvLayer(in_channels, out_channels, kernel_size, stride, is_Last)
        self.is_last = is_Last

    def forward(self, x):
        out = self.dense_conv(x)
        if self.is_last == False:
            # 按第二个维度进行拼接    为了实现densenet
            out = torch.cat([x, out], 1)
        return out

所有DenseNet

out_channels_def = 16
        denseblock = []
        # densenet
        denseblock += [DenseConv2d(in_channels, out_channels_def, kernel_size, stride),
                       DenseConv2d(in_channels + out_channels_def, out_channels_def, kernel_size, stride),
                       DenseConv2d(in_channels + out_channels_def * 2, out_channels_def, kernel_size, stride),
                       DenseConv2d(in_channels + out_channels_def * 3, out_channels_def, kernel_size, stride),
                       DenseConv2d(in_channels + out_channels_def * 4, out_channels_def, kernel_size, stride),
                       DenseConv2d(in_channels + out_channels_def * 5, out_channels_def, kernel_size, stride),
                       DenseConv2d(in_channels + out_channels_def * 6, out_channels_def, kernel_size, stride, True)
                       ]
        self.denseblock = nn.Sequential(*denseblock)

训练

def train():
    net = Net(3, 3, 1)
    net.cuda()
    net.train()
    # 训练
    cirterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.5)
    for epoch in range(50):
        running_loss = 0.0
        for i, data in enumerate(train_loader, 0):
            inputs, labels = data
            inputs, labels = Variable(inputs), Variable(labels)
            inputs = torch.tensor(inputs)
            labels = torch.tensor(labels)
            inputs = inputs.cuda()
            labels = labels.cuda()
            optimizer.zero_grad()  # 优化器清零
            outputs = net(inputs)
            loss = cirterion(outputs, labels)
            loss.backward()
            optimizer.step()  # 优化
            running_loss += loss.item()
            if i % 5 == 0:
                print('[%d %5d] loss: %.3f' % (epoch + 1, i + 1, running_loss / 200))
                running_loss = 0.0
        # 每个epoch存储一个模型
        torch.save(net, 'shoenet' + epoch.__str__() + '.pth')
    torch.save(net, 'shoenet.pth')

总结

第一次写神经网络的代码,花了好久才搞出来,虽然很多代码都是照葫芦画瓢,但对我这种小白来说确实蛮难的,但总归是搞出来了,来看下结果把。
在这里插入图片描述
可以看到哈,分类还是ok 的,那就到这了,溜了溜了


http://www.niftyadmin.cn/n/5172.html

相关文章

一文了解SpringBoot

目录 什么是SpringBoot? SpringBoot的优点 SpringBoot项目搭建 创建一个普通的maven项目 修改pom.xml(检查引入的依赖是否正确) 准备SpringBoot的启动配置文件 开发SpringBoot的启动类 输出HelloWord 注意事项 SpringBoot的配置文件…

项目实战 | YOLOv5 + Tesseract-OCR 实现车牌号文本识别

项目实战 | YOLOv5 Tesseract-OCR 实现车牌号文本识别 最近看到了各种各样的车牌识别,觉得挺有意思,自己也简单搞一个玩玩😼。 传统的图像处理算法我也不太会,就直接用深度学习的方法实现吧。 文章目录项目实战 | YOLOv5 Tesser…

C++ opencv图像存储和MAT容器

1.图像在内存之中的存储方式: 图像矩阵的大小取决于所用的颜色模型,确切说,取决于所用通道数。如果是灰度图像,矩阵就会如图5.1所示。 对于多通道图像来说,矩阵中的列会包含多个子列,其子列个数与通道数相同&#xf…

【优化调度】基于matlab遗传算法求解公交车调度排班优化问题【含Matlab源码 2212期】

⛄ 一、 遗传算法简介 1 引言 公交排班问题是城市公交调度的核心内容,是公交调度人员、司乘人员进行工作以及公交车辆正常运行的基本依据。行车时刻表是按照线路的当前客流量情况,确定发车频率,提供线路车辆的首、末车时间。它是公交企业对社会的承诺,决定着为乘客服务的水平,…

Hadoop的eclipse搭建(客观莫划走,留下来看一眼(适用人群学生初学,其他人看看就行))

前言:Hadoop的eclipse搭建是建立在Hadoop的安装之后进行的,因为Linux上的Hadoop和Windows上的Hadoop版本要求一致,不一致可能会出现某些问题 准备工作:Java的安装包、eclipse的安装包、Hadoop的包(Windows的Hadoop安装…

计算机二级真题练习

1、下面不符合软件设计准则的是()。 A、设计单入口、单出口的模块 B、模块规模尽可能小 C、提高模块的独立性 D、减少模块接口和界面的复杂性 正确答案:B 笞疑:【解析】软件设计准则:Q提高模块独立性;②模块规摸应该适中;③深度、宽度、扇出和扇入都…

【毕业设计】口罩佩戴检测系统 - opencv 卷积神经网络 机器视觉 深度学习

文章目录🚩 0 简介🚩1 课题背景🚩 2 口罩佩戴算法实现2.1 YOLO 模型概览2.2 YOLOv32.3 YOLO 口罩佩戴检测实现2.4 实现代码2.5 检测效果🚩 3 口罩佩戴检测算法评价指标3.1 准确率(Accuracy)3.2 精确率(Prec…

加权黑猩猩优化算法(WChOA)附Matlab代码

✅作者简介:热爱科研的Matlab仿真开发者,修心和技术同步精进,matlab项目合作可私信。 🍎个人主页:Matlab科研工作室 🍊个人信条:格物致知。 更多Matlab仿真内容点击👇 智能优化算法 …