当前位置:网站首页>初步认识pytorch
初步认识pytorch
2022-06-27 08:47:00 【CNRalap】
初步认识pytorch
是什么
PyTorch 是一个基于 python 的科学计算包,有以下特性:
作为 NumPy 的替代品,可以利用 GPU 的性能进行计算。
作为一个高灵活性,速度快的深度学习平台。
有什么
TENSORS
张量如同数组和矩阵一样, 是一种特殊的数据结构,实际上就是一个多维数组。在PyTorch中, 神经网络的输入、输出以及网络的参数等数据, 都是使用张量来进行描述。张量的使用和Numpy中的ndarrays很类似, 区别在于张量可以在GPU或其它专用硬件上运行, 这样可以得到更快的加速效果。
张量表示由一个数值组成的数组,这个数组可能有多个维度。具有一个轴的张量对应数学上的向量(vector);具有两个轴的张量对应数学上的矩阵(matrix); 具有两个轴以上的张量没有特殊的数学名称。
张量对象的3个属性:
rank:number of dimensions(维度的数目)
shape: number of rows and columns(行列的数目)
type: data type of tensor’s elements(数据类型)
)
1.张量的索引和切片
就像在任何其他Python数组中一样,张量中的元素可以通过索引访问。与任何Python数组一样:第一个元素的索引是0,最后一个元素索引是-1;可以指定范围以包含第一个元素和最后一个之前的元素。
#x = torch.ones(3, 4)
#x[0] = 12
tensor([[12., 12., 12., 12.],
[ 1., 1., 1., 1.],
[ 1., 1., 1., 1.]])
#x[-1]=12
tensor([[ 1., 1., 1., 1.],
[ 1., 1., 1., 1.],
[12., 12., 12., 12.]])
如果我们想为多个元素赋值相同的值,我们只需要索引所有元素,然后为它们赋值。例如,[0:2, :]访问第0行和第1行,其中“:”代表沿轴1(列)的所有元素。虽然我们讨论的是矩阵的索引,但这也适用于向量和超过2个维度的张量。
#X[0:2, :] = 12(从第0行开始数两行,即第0行第1行所有列均为12)
#X
tensor([[12., 12., 12., 12.],
[12., 12., 12., 12.],
[ 8., 9., 10., 11.]])
2.张量的拼接
通过torch.cat方法将一组张量按照指定的维度进行拼接, 也可以参考torch.stack方法。这个方法也可以实现拼接操作, 但和torch.cat稍微有点不同。
使用torch.cat,我们只需要提供张量列表,并给出沿哪个轴连结。下面的例子分别演示了当我们沿行(轴-0,形状的第一个元素,一行一行拼接) 和按列(轴-1,形状的第二个元素,一列一列拼接)连结两个矩阵时,会发生什么情况。
X = torch.arange(12, dtype=torch.float32).reshape((3,4))
Y = torch.tensor([[2.0, 1, 4, 3], [1, 2, 3, 4], [4, 3, 2, 1]])
torch.cat((X, Y), dim=0), torch.cat((X, Y), dim=1)
输出
#dim=0
(tensor([[ 0., 1., 2., 3.],
[ 4., 5., 6., 7.],
[ 8., 9., 10., 11.],
[ 2., 1., 4., 3.],
[ 1., 2., 3., 4.],
[ 4., 3., 2., 1.]]),
#dim=1
tensor([[ 0., 1., 2., 3., 2., 1., 4., 3.],
[ 4., 5., 6., 7., 1., 2., 3., 4.],
[ 8., 9., 10., 11., 4., 3., 2., 1.]]))
3.张量的乘积和矩阵乘法
逐个元素相乘结果
print(f"tensor.mul(tensor): \n {tensor.mul(tensor)} \n")
等价写法:
print(f"tensor * tensor: \n {tensor * tensor}")
tensor.mul(tensor):
tensor([[1., 0., 1., 1.],
[1., 0., 1., 1.],
[1., 0., 1., 1.],
[1., 0., 1., 1.]])
tensor * tensor:
tensor([[1., 0., 1., 1.],
[1., 0., 1., 1.],
[1., 0., 1., 1.],
[1., 0., 1., 1.]])
张量与张量的矩阵乘法
print(f"tensor.matmul(tensor.T): \n {tensor.matmul(tensor.T)} \n")
# 等价写法:
print(f"tensor @ tensor.T: \n {tensor @ tensor.T}")
tensor.matmul(tensor.T):
tensor([[3., 3., 3., 3.],
[3., 3., 3., 3.],
[3., 3., 3., 3.],
[3., 3., 3., 3.]])
tensor @ tensor.T:
tensor([[3., 3., 3., 3.],
[3., 3., 3., 3.],
[3., 3., 3., 3.],
[3., 3., 3., 3.]])
4. 自动赋值运算
自动赋值运算通常在方法后有 _ 作为后缀, 例如: x.copy_(y), x.t_()操作会改变 x 的取值。
print(tensor, "\n")
tensor.add_(5)
print(tensor)
tensor([[1., 0., 1., 1.],
[1., 0., 1., 1.],
[1., 0., 1., 1.],
[1., 0., 1., 1.]])
tensor([[6., 5., 6., 6.],
[6., 5., 6., 6.],
[6., 5., 6., 6.],
[6., 5., 6., 6.]])
5.广播机制
a和b分别是3X1和1X2的矩阵,如果让它们相加,它们的形状不匹配。我们将两个矩阵广播为更大的3X2矩阵,如下所示:矩阵a将复制列,矩阵b将复制行,然后再按元素相加。
#a和b
(tensor([[0],
[1],
[2]]),
tensor([[0, 1]]))
#a+b
tensor([[0, 1],
[1, 2],
[2, 3]])
数据集预处理
CSV文件可以调用pandas中的read_csv函数进行查看,并用插值法或者删除法对缺失值进行处理。最关键的是,这些文件我们需要转换成张量格式。
最后处理完的数据:
#input
NumRooms Alley_Pave Alley_nan
0 3.0 1 0
1 2.0 0 1
2 4.0 0 1
3 3.0 0 1
#output
Price
127500
106000
178100
140000
inputs和outputs中的所有条目都是数值类型,它们可以转换为张量格式:
x, y = torch.tensor(inputs.values), torch.tensor(outputs.values)
print(x, y)
输出如下:
tensor([[3., 1., 0.],
[2., 0., 1.],
[4., 0., 1.],
[3., 0., 1.]], dtype=torch.float64) tensor([127500, 106000, 178100, 140000])
torch.autograd
torch.autograd是 PyTorch 的自动差分引擎,可为神经网络训练提供支持。首先我们来介绍一下神经网络,神经网络(NN)是在某些输入数据上执行的嵌套函数的集合。 这些函数由参数(由权重和偏差组成)定义,这些参数在 PyTorch 中存储在张量中。
训练 NN 分为两个步骤:
正向传播:在正向传播中,NN 对正确的输出进行最佳猜测。 它通过其每个函数运行输入数据以进行猜测。
反向传播:在反向传播中,NN 根据其猜测中的误差调整其参数。 它通过从输出向后遍历,收集有关函数参数(梯度)的误差导数并使用梯度下降来优化参数来实现。
让我们来看一个训练步骤。 对于此示例,我们从torchvision加载了经过预训练的 resnet18 模型。 我们创建一个随机数据张量来表示具有 3 个通道的单个图像,高度&宽度为 64,其对应的label初始化为一些随机值。
import torch, torchvision
model = torchvision.models.resnet18(pretrained=True)
data = torch.rand(1, 3, 64, 64)
labels = torch.rand(1, 1000)
接下来,我们通过模型的每一层运行输入数据以进行预测。这是正向传播:prediction = model(data) # forward pass
我们使用模型的预测和相应的标签来计算误差(loss)。 下一步是通过网络反向传播此误差。 当我们在误差张量上调用.backward()时,开始反向传播。 然后,Autograd 会为每个模型参数计算梯度并将其存储在参数的.grad属性中。
loss = (prediction - labels).sum()
loss.backward() # backward pass
接下来,我们加载一个优化器,在本例中为 SGD,学习率为 0.01,动量为 0.9。 我们在优化器中注册模型的所有参数。
optim = torch.optim.SGD(model.parameters(), lr=1e-2, momentum=0.9)
最后,我们调用.step()启动梯度下降。 优化器通过.grad中存储的梯度来调整每个参数。
optim.step() #gradient descent
神经网络
可以使用torch.nn包构建神经网络。现在您已经了解了autograd,nn依赖于autograd来定义模型并对其进行微分。nn.Module包含层,以及返回output的方法forward(input)。
例如,查看以下对数字图像进行分类的网络:
(卷积网)
这是一个简单的前馈网络。它获取输入,将其一层又一层地馈入,然后最终给出输出。
神经网络的典型训练过程如下:
1.定义具有一些可学习参数(或权重)的神经网络
2.遍历输入数据集
3.通过网络处理输入
4.计算损失(输出正确的距离有多远)
损失函数采用一对(输出,目标)输入,并计算一个值,该值估计输出与目标之间的距离。nn包下有几种不同的损失函数。 一个简单的损失是:nn.MSELoss,它计算输入和目标之间的均方误差。
5.将梯度传播回网络参数
要反向传播误差,我们要做的只是对loss.backward()。 不过,您需要清除现有的梯度,否则梯度将累积到现有的梯度中。
6.通常使用简单的更新规则来更新网络的权重:weight = weight - learning_rate * gradient
使用的最简单的更新规则是随机梯度下降(SGD),或者不同的更新规则,例如 SGD,Nesterov-SGD,Adam,RMSProp。
训练分类器
当您必须处理图像,文本,音频或视频数据时,可以使用将数据加载到 NumPy 数组中的标准 Python 包。 然后,您可以将该数组转换为torch.*Tensor。
对于图像,Pillow,OpenCV 等包很有用
对于音频,请使用 SciPy 和 librosa 等包
对于文本,基于 Python 或 Cython 的原始加载,或者 NLTK 和 SpaCy 很有用
专门针对视觉,我们创建了一个名为torchvision的包,其中包含用于常见数据集(例如 Imagenet,CIFAR10,MNIST 等)的数据加载器,以及用于图像(即torchvision.datasets和torch.utils.data.DataLoader)的数据转换器。
这提供了极大的便利,并且避免了编写样板代码。
训练分类器大致有以下步骤:
1.使用torchvision加载并标准化 CIFAR10 训练和测试数据集(本例中用CIFAR10)
2.定义卷积神经网络
3.定义损失函数
4.根据训练数据训练网络
5.在测试数据上测试网络
一些必备知识
神经网络中的 Epochs, Batchsize, Iterations 具体是什么
梯度下降:
梯度下降法是机器学习中经典的优化算法之一,用来求解复杂曲线的最小值。“梯度”是指某一函数在该点处的方向导数沿着该方向取得最大值,即函数在该点处沿着该方向(此梯度的方向)变化最快,变化率最大(为该梯度的模)。“下降”是指下降递减的过程。梯度下降法是多次迭代求解的,梯度下降的迭代质量有助于使模型尽可能拟合训练数据。即我们可以简单理解为,控制梯度来控制曲线的变化,从而使模型尽可能的拟合训练数据。
batch:
前文提及到了梯度下降,梯度下降每次的参数更新有两种方式:
第一种,遍历全部数据集算一次损失函数,然后算函数对各个参数的梯度,更新梯度。这种方法每更新一次参数都要把数据集里的所有样本都看一遍,计算量开销大,计算速度慢,不支持在线学习,这称为Batch gradient descent,批梯度下降。
第二种,每看一个数据就算一下损失函数,然后求梯度更新参数,这个称为随机梯度下降,stochastic gradient descent。这个方法速度比较快,但是收敛性能不太好,可能在最优点附近晃来晃去,hit不到最优点。两次参数的更新也有可能互相抵消掉,造成目标函数震荡的比较剧烈。
为了克服两种方法的缺点,现在一般采用的是一种折中手段,mini-batch gradient decent,小批的梯度下降,这种方法把数据分为若干个批,按批来更新参数,这样,一个批中的一组数据共同决定了本次梯度的方向,下降起来就不容易跑偏,减少了随机性。另一方面因为批的样本数与整个数据集相比小了很多,计算量也不是很大。基本上现在的梯度下降都是基于mini-batch的,所以深度学习框架的函数中经常会出现batch_size,就是指这个。
epoch:
为了获得性能良好的神经网络,网络定型过程中需要进行许多关于所用设置(超参数)的决策。超参数之一是定型周期(epoch)的数量,epoch被定义为向前和向后传播中所有批次的单次训练迭代,这意味着1个epoch是将所有的数据输入网络完成一次向前计算及反向传播。简单说,epochs指的就是训练过程中数据将被“轮”多少次,亦即应当完整遍历数据集多少次(一次为一个epoch)。
如果epoch数量太少,网络有可能发生欠拟合(即对于定型数据的学习不够充分);如果epoch数量太多,则有可能发生过拟合(即网络对定型数据中的“噪声”而非信号拟合)。
iterations:
iteration即为迭代,每一次迭代都是一次权重更新,每一次权重更新需要batch_size个数据进行Forward运算得到损失函数,再BP算法更新参数。所以,iterations就是完成一次epoch所需的batch个数。batch numbers就是iterations。
举个例子
训练集有1000个样本,batchsize=10,那么训练完整个样本集需要:
100次iteration,1次epoch。
具体的计算公式为:
one epoch = numbers of iterations = N = 训练样本的数量/batch_size
怎么用
安装pytorch
安装教程:PyTorch 最新安装教程(2021-07-27)
运用pytorch
背景知识:
图像分类,就是对于一个给定的图像,预测它属于哪个类别签。图像是3维数组,数组元素是取值范围从0到255的整数。数组的尺寸是宽度x高度x3,其中这个3代表的是红、绿和蓝3个颜色通道。
CNN
一个简单的卷积神经网络(CNN)是由各种层按照顺序排列组成,网络中的每个层使用一个可微分的函数将数据从一层传递到下一层。卷积神经网络主要由三种类型的层构成:卷积层,池化层和全连接层。通过将这些层叠加起来,就可以构建一个完整的卷积神经网络。
(1)卷积层:
卷积层可以说是卷积神经架构中最重要的步骤之一,涉及到特征表达的好坏,同时也是占据整个网络95%以上的计算量。卷积是一种线性的、平移不变性的运算。
(2)非线性激活单元:
非线性激活单元受启发于人类大脑的神经元模型。在神经元模型中,树突将信号传递到细胞体,信号在细胞体中组合相加。如果最终之和高于某个阈值,那么神经元将会激活,向其轴突输出一个峰值信号传递至下一个神经元。
引入非线性激活函数的主要目的是增加神经网络的非线性性。因为如果没有非线性激活函数的话,每一层输出都是上层输入的线性函数,因此,无论神经网络有多少层,得到的输出都是线性函数,这就是原始的感知机模型,这种线性性不利于发挥神经网络的优势。
常用的非线性激活单元如下图所示。目前比较常用的有ReLu和LReLu,Logistic(Sigmoid)单元由于其饱和区特性导致整个网络梯度消失而逐渐退出历史舞台(有的时候最后一层会用Sigmoid将输出限制在0.0-1.0)。
(3)池化层:
通常,在连续的卷积层之间会周期性地插入一个池化层。它的作用是逐渐降低数据体的空间尺寸,这样的话就能减少网络中参数的数量,使得计算资源耗费变少,也能有效控制过拟合,如下图所示。池化层通常使用MAX操作,对输入数据体的每一个切片独立进行操作,改变它的空间尺寸。最常见的形式是使用尺寸2x2的滤波器,以步长为2来对每个深度切片进行降采样,将其中75%的激活信息都丢掉。每个MAX操作是从4个数字中取最大值(也就是在深度切片中某个2x2的区域)。注意在池化的过程中,数据体的通道数保持不变。
(4)全链接层:
所谓全链接层即是传统的神经网络,每一个神经单元都和上一层所有的神经单元密集连接。如今,全连接层由于其巨大的参数量易过拟合以及不符合人类对图像的局部感知原理,一般不参与图像的特征提取(已由卷积层替代),只用于最后的线性分类, 相当于在提取的高层特征向量上进行线性组合并且输出最后的预测结果。
ResNet
残差网络(Residual Network)是ILSVRC 2015的冠军模型,由何恺明等实现。ResNet的结构可以极快的加速神经网络的训练,模型的准确率也有比较大的提升。同时ResNet的推广性非常好,甚至可以直接用到InceptionNet网络中。
在ResNet网络中有如下几个亮点:
提出residual结构(残差结构),并搭建超深的网络结构(突破1000层)。
使用Batch Normalization加速训练(丢弃dropout)。
传统卷积神经网络都是通过将一系列卷积层与下采样层进行堆叠得到的。但是当堆叠到一定网络深度时,就会出现两个问题:
梯度消失或梯度爆炸。(深层网络的结构问题或不合适的损失函数)
退化问题(degradation problem)。(加深网络的层数希望深层的网络的表现能比浅层好,或者是希望它的表现至少和浅层网络持平(相当于直接复制浅层网络的特征),但实际的结果却是让深度网络退化,形成原因可能是非线性激活函数Relu的存在,每次输入到输出的过程都几乎是不可逆的,这也造成了许多不可逆的信息损失。)
在ResNet论文中说通过数据的预处理以及在网络中使用BN(Batch Normalization)层能够解决梯度消失或者梯度爆炸问题,并提出了residual结构(残差结构)来减轻退化问题。
残差结构:通过对深度网络退化问题的认识我们已经明白,要让之不退化,根本原因就是如何做到恒等映射。事实上,已有的神经网络很难拟合潜在的恒等映射函数H(x) = x。但如果把网络设计为H(x) = F(x) + x,即直接把恒等映射作为网络的一部分,就可以把问题转化为学习一个残差函数F(x) = H(x) - x。只要F(x) = 0,就构成了一个恒等映射H(x) = x。 而且,拟合残差至少比拟合恒等映射容易得多。当然,实际情况下残差F(x)不会为0,x肯定是很难达到最优的,但是总会有那么一个时刻它能够无限接近最优解。采用ResNet的话,就只用小小的更新F(x)部分的权重值就行了,可以更好地学到新的特征。我们看一下残差结构与正常结构对比图:
我们可以看到,残差结构比正常的结构多了右侧的曲线,这个曲线也叫作shortcut connection,通过跳接在激活函数前,将上一层(或几层)的输出与本层输出相加,将求和的结果输入到激活函数作为本层的输出。
而残差结构又是由ResNet block组成的,ResNet block有两种,一种两层结构,一种是三层的bottleneck结构,即将两个3x3的卷积层替换为1x1 + 3x3 + 1x1,它通过1x1 conv来巧妙地缩减或扩张feature map维度,从而使得我们的3x3 conv的filters数目不受上一层输入的影响,它的输出也不会影响到下一层。中间3x3的卷积层首先在一个降维1x1卷积层下减少了计算,然后在另一个1x1的卷积层下做了还原。既保持了模型精度又减少了网络参数和计算量,节省了计算时间。
实际运用:识别花朵
0.项目一览
其中1到4为测试网络的验证集里的图片。
json为一个配置文件,给每一类花做标记。
main为新建项目默认文件,无用。
resNET34.path为模型权重文件。
1.构建模型
#model.py
import torch.nn as nn
import torch
#18/34
class BasicBlock(nn.Module):
expansion = 1 #每一个conv的卷积核个数的倍数
def __init__(self, in_channel, out_channel, stride=1, downsample=None):#downsample对应虚线残差结构
super(BasicBlock, self).__init__()
self.conv1 = nn.Conv2d(in_channels=in_channel, out_channels=out_channel,
kernel_size=3, stride=stride, padding=1, bias=False)
self.bn1 = nn.BatchNorm2d(out_channel)#BN处理
self.relu = nn.ReLU()
self.conv2 = nn.Conv2d(in_channels=out_channel, out_channels=out_channel,
kernel_size=3, stride=1, padding=1, bias=False)
self.bn2 = nn.BatchNorm2d(out_channel)
self.downsample = downsample
def forward(self, x):
identity = x #捷径上的输出值
if self.downsample is not None:
identity = self.downsample(x)
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
out += identity
out = self.relu(out)
return out
#50,101,152
class Bottleneck(nn.Module):
expansion = 4#4倍
def __init__(self, in_channel, out_channel, stride=1, downsample=None):
super(Bottleneck, self).__init__()
self.conv1 = nn.Conv2d(in_channels=in_channel, out_channels=out_channel,
kernel_size=1, stride=1, bias=False) # squeeze channels
self.bn1 = nn.BatchNorm2d(out_channel)
self.relu = nn.ReLU(inplace=True)
# -----------------------------------------
self.conv2 = nn.Conv2d(in_channels=out_channel, out_channels=out_channel,
kernel_size=3, stride=stride, bias=False, padding=1)
self.bn2 = nn.BatchNorm2d(out_channel)
self.relu = nn.ReLU(inplace=True)
# -----------------------------------------
self.conv3 = nn.Conv2d(in_channels=out_channel, out_channels=out_channel*self.expansion,#输出*4
kernel_size=1, stride=1, bias=False) # unsqueeze channels
self.bn3 = nn.BatchNorm2d(out_channel*self.expansion)
self.relu = nn.ReLU(inplace=True)
self.downsample = downsample
def forward(self, x): #正向传播
identity = x
if self.downsample is not None:
identity = self.downsample(x)
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
out = self.relu(out)
out = self.conv3(out)
out = self.bn3(out)
out += identity
out = self.relu(out)
return out
class ResNet(nn.Module):
def __init__(self, block, blocks_num, num_classes=1000, include_top=True):#block残差结构 include_top为了之后搭建更加复杂的网络
super(ResNet, self).__init__()
self.include_top = include_top
self.in_channel = 64
self.conv1 = nn.Conv2d(3, self.in_channel, kernel_size=7, stride=2,
padding=3, bias=False)
self.bn1 = nn.BatchNorm2d(self.in_channel)
self.relu = nn.ReLU(inplace=True)
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
self.layer1 = self._make_layer(block, 64, blocks_num[0])
self.layer2 = self._make_layer(block, 128, blocks_num[1], stride=2)
self.layer3 = self._make_layer(block, 256, blocks_num[2], stride=2)
self.layer4 = self._make_layer(block, 512, blocks_num[3], stride=2)
if self.include_top:
self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) # output size = (1, 1)自适应
self.fc = nn.Linear(512 * block.expansion, num_classes)
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
def _make_layer(self, block, channel, block_num, stride=1):
downsample = None
if stride != 1 or self.in_channel != channel * block.expansion:
downsample = nn.Sequential(
nn.Conv2d(self.in_channel, channel * block.expansion, kernel_size=1, stride=stride, bias=False),
nn.BatchNorm2d(channel * block.expansion))
layers = []
layers.append(block(self.in_channel, channel, downsample=downsample, stride=stride))
self.in_channel = channel * block.expansion
for _ in range(1, block_num):
layers.append(block(self.in_channel, channel))
return nn.Sequential(*layers)
def forward(self, x):
x = self.conv1(x)
x = self.bn1(x)
x = self.relu(x)
x = self.maxpool(x)
x = self.layer1(x)
x = self.layer2(x)
x = self.layer3(x)
x = self.layer4(x)
if self.include_top:
x = self.avgpool(x)
x = torch.flatten(x, 1)
x = self.fc(x)
return x
def resnet34(num_classes=1000, include_top=True):
return ResNet(BasicBlock, [3, 4, 6, 3], num_classes=num_classes, include_top=include_top)
def resnet101(num_classes=1000, include_top=True):
return ResNet(Bottleneck, [3, 4, 23, 3], num_classes=num_classes, include_top=include_top)
2.下载数据集
DATA_URL = 'http://download.tensorflow.org/example_images/flower_photos.tgz'
3.将数据集分类
先执行split.py函数:
#spile_data.py
import os
from shutil import copy
import random
def mkfile(file):
if not os.path.exists(file):
os.makedirs(file)
file = 'flower_data/flower_photos'
flower_class = [cla for cla in os.listdir(file) if ".txt" not in cla]
mkfile('flower_data/train')
for cla in flower_class:
mkfile('flower_data/train/'+cla)
mkfile('flower_data/val')
for cla in flower_class:
mkfile('flower_data/val/'+cla)
split_rate = 0.1
for cla in flower_class:
cla_path = file + '/' + cla + '/'
images = os.listdir(cla_path)
num = len(images)
eval_index = random.sample(images, k=int(num*split_rate))
for index, image in enumerate(images):
if image in eval_index:
image_path = cla_path + image
new_path = 'flower_data/val/' + cla
copy(image_path, new_path)
else:
image_path = cla_path + image
new_path = 'flower_data/train/' + cla
copy(image_path, new_path)
print("\r[{}] processing [{}/{}]".format(cla, index+1, num), end="") # processing bar
print()
print("processing done!")
将数据集分类成:
3.执行train,根据训练数据训练网络
train文件如下:
#train.py
import torch
import torch.nn as nn
from torchvision import transforms, datasets
import json
import matplotlib.pyplot as plt
import os
import torch.optim as optim
from ResNet_model import resnet34, resnet101
import torchvision.models.resnet
if __name__ == '__main__': #进入多线程,把要运行的写到main中
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)
data_transform = {
"train": transforms.Compose([transforms.RandomResizedCrop(224),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]), # 来自官网参数
"val": transforms.Compose([transforms.Resize(256), # 将最小边长缩放到256
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])}
# data_root = os.getcwd()
image_path = r'E:\迅雷下载\flower_data/'
train_dataset = datasets.ImageFolder(root=image_path + "train",
transform=data_transform["train"])
train_num = len(train_dataset)
# {'daisy':0, 'dandelion':1, 'roses':2, 'sunflower':3, 'tulips':4}
flower_list = train_dataset.class_to_idx
cla_dict = dict((val, key) for key, val in flower_list.items())
# write dict into json file
json_str = json.dumps(cla_dict, indent=4)
with open('class_indices.json', 'w') as json_file:
json_file.write(json_str)
batch_size = 16
train_loader = torch.utils.data.DataLoader(train_dataset,
batch_size=batch_size, shuffle=True,
num_workers=2)
validate_dataset = datasets.ImageFolder(root=image_path + "/val",
transform=data_transform["val"])
val_num = len(validate_dataset)
validate_loader = torch.utils.data.DataLoader(validate_dataset,
batch_size=batch_size, shuffle=False,
num_workers=2)
# net = resnet34()
net = resnet34(num_classes=5)
# load pretrain weights
# model_weight_path = "./resnet34-pre.pth"
# missing_keys, unexpected_keys = net.load_state_dict(torch.load(model_weight_path), strict=False)#载入模型参数
# for param in net.parameters():
# param.requires_grad = False
# change fc layer structure
# inchannel = net.fc.in_features
# net.fc = nn.Linear(inchannel, 5)
net.to(device)
loss_function = nn.CrossEntropyLoss()
optimizer = optim.Adam(net.parameters(), lr=0.0001)
best_acc = 0.0
save_path = './resNet34.pth'
for epoch in range(3):
# train
net.train()
running_loss = 0.0
for step, data in enumerate(train_loader, start=0):
images, labels = data
optimizer.zero_grad()
logits = net(images.to(device))
loss = loss_function(logits, labels.to(device))
loss.backward()
optimizer.step()
# print statistics
running_loss += loss.item()
# print train process
rate = (step + 1) / len(train_loader)
a = "*" * int(rate * 50)
b = "." * int((1 - rate) * 50)
print("\rtrain loss: {:^3.0f}%[{}->{}]{:.4f}".format(int(rate * 100), a, b, loss), end="")
print()
# validate
net.eval()
acc = 0.0 # accumulate accurate number / epoch
with torch.no_grad():
for val_data in validate_loader:
val_images, val_labels = val_data
outputs = net(val_images.to(device)) # eval model only have last output layer
# loss = loss_function(outputs, test_labels)
predict_y = torch.max(outputs, dim=1)[1]
acc += (predict_y == val_labels.to(device)).sum().item()
val_accurate = acc / val_num
if val_accurate > best_acc:
best_acc = val_accurate
torch.save(net.state_dict(), save_path)
print('[epoch %d] train_loss: %.3f test_accuracy: %.3f' %
(epoch + 1, running_loss / step, val_accurate))
print('Finished Training')
运行结果如下:
第二次训练中发现GPU的显存没有完全利用,可能参数设置仍待调整。
4.执行predict文件,在测试数据上测试网络
predict文件如下:
#predict.py
import torch
from ResNet_model import resnet34
from PIL import Image
from torchvision import transforms
import matplotlib.pyplot as plt
import json
data_transform = transforms.Compose(
[transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
# load image
img = Image.open("4.jpg")
plt.imshow(img)
# [N, C, H, W]
img = data_transform(img)
# expand batch dimension
img = torch.unsqueeze(img, dim=0)
# read class_indict
try:
json_file = open('./class_indices.json', 'r')
class_indict = json.load(json_file)
except Exception as e:
print(e)
exit(-1)
# create model
model = resnet34(num_classes=5)
# load model weights
model_weight_path = "./resNet34.pth"
model.load_state_dict(torch.load(model_weight_path))
model.eval()
with torch.no_grad():
# predict class
output = torch.squeeze(model(img))
predict = torch.softmax(output, dim=0)
predict_cla = torch.argmax(predict).numpy()
print(class_indict[str(predict_cla)], predict[predict_cla].numpy())
plt.show()
运行结果:(借助了Matplotlib这个数据可视化库)
书籍推荐
边栏推荐
- How Oracle converts strings to multiple lines
- A classic interview question covering 4 hot topics
- The background prompt module for accessing fastadmin after installation does not exist
- Fake constructor???
- 数字IC-1.9 吃透通信协议中状态机的代码编写套路
- Order by injection of SQL injection
- About the problem that the El date picker Click to clear the parameter and make it null
- 2022.06.26(LC_6100_统计放置房子的方式数)
- 【原创】TypeScript字符串utf-8编码解码
- DataV轮播表组件dv-scroll-board宽度问题
猜你喜欢

经典的一道面试题,涵盖4个热点知识

AQS underlying source code of concurrent programming JUC

NoSQL database redis installation

Matlab tips (18) matrix analysis -- entropy weight method

Enumeration? Constructor? Interview demo

一种太阳能电荷泵供电电路的方案设计

Correctly understand MySQL mvcc

Redis的持久化机制
![[cloud native] 2.3 kubernetes core practice (Part 1)](/img/f8/dbd2546e775625d5c98881e7745047.png)
[cloud native] 2.3 kubernetes core practice (Part 1)

Digital ic-1.9 understands the coding routine of state machine in communication protocol
随机推荐
Redis的事务
MySQL environment variable configuration tutorial
fastadmin 安装后访问后台提示模块不存在
When multiple network devices exist, how to configure their Internet access priority?
Redis配置文件详解
Analysis of orthofinder lineal homologous proteins and result processing
100% understanding of 5 IO models
即构「畅直播」,全链路升级的一站式直播服务
This, constructor, static, and inter call must be understood!
Design of a solar charge pump power supply circuit
Markem imaje马肯依玛士喷码机维修9450E打码机维修
RockerMQ消息发送模式
Persistence mechanism of redis
三道基础面试题总结
IMX8QXP DMA资源和使用(未完结)
[cloud native] 2.3 kubernetes core practice (Part 1)
March into machine learning -- Preface
Win10 add right-click menu for any file
Redis master-slave replication and sentinel mode
Filter filter