程序笔记   发布时间:2022-07-16  发布网站:大佬教程  code.js-code.com
大佬教程收集整理的这篇文章主要介绍了Pytorch CIFAR10图像分类 数据加载与可视化篇大佬教程大佬觉得挺不错的,现在分享给大家,也给大家做个参考。

Pytorch CIFAR10图像分类 数据加载与可视化篇

文章目录

  • Pytorch CIFAR10图像分类 数据加载与可视化篇
      • 1.数据读取
      • 2. 查看数据(格式࿰c;大小࿰c;形状)
      • 3. 查看图片
        • np.ndarray转为torch.Tensor
Pytorch一般有以下几个流程

  1. 数据读取
  2. 数据处理
  3. 搭建网络
  4. 模型训练
  5. 模型上线

这里会先讲一下关于CIFAR10的数据加载和图片可视化࿰c;之后的模型篇会对网络进行介绍和实现。

1.数据读取

CIFAR-10 是由 Hinton 的学生 Alex Krizhevsky 和 Ilya Sutskever 整理的一个用于识别普适物体的小型数据集。一共包含 10 个类别的 RGB 彩色图 片:飞机( arplane )、汽车( automobile )、鸟类( bird )、猫( cat )、鹿( deer )、狗( dog )、蛙类( frog )、马( horse )、船( ship )和卡车( truck )。图片的尺寸为 32×32 ࿰c;数据集中一共有 50000 张训练圄片和 10000 张测试图片。

与 MNIST 数据集中目比࿰c; CIFAR-10 具有以下不同点:

  • CIFAR-10 是 3 通道的彩色 RGB 图像࿰c;而 MNIST 是灰度图像。
  • CIFAR-10 的图片尺寸为 32×32࿰c; 而 MNIST 的图片尺寸为 28×28࿰c;比 MNIST 稍大。
  • 相比于手写字符࿰c; CIFAR-10 含有的是现实世界中真实的物体࿰c;不仅噪声很大࿰c;而且物体的比例、 特征都不尽相同࿰c;这为识别带来很大困难。

Pytorch CIFAR10图像分类 数据加载与可视化篇

首先使用torchvision加载和归一化我们的训练数据和测试数据。

a、torchvision这个东西࿰c;实现了常用的一些深度学习的相关的图像数据的加载功能࿰c;比如Cifar10、Imagenet、Mnist等等的࿰c;保存在torchvision.datasets模块中。

b、同时࿰c;也封装了一些处理数据的方法。保存在torchvision.transforms模块中

c、还封装了一些模型和工具封装在相应模型中,比如torchvision.models当中就包含了AlexNet࿰c;VGG࿰c;ResNet࿰c;SqueezeNet等模型。

由于torchvision的datasets的输出是[0,1]的PILImage࿰c;所以我们先先归一化为[-1,1]的Tensor

首先定义了一个变换transform࿰c;利用的是上面提到的transforms模块中的Compose( )把多个变换组合在一起࿰c;可以看到这里面组合了ToTensor和Normalize这两个变换

transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))前面的(0.5࿰c;0.5࿰c;0.5) 是 R G B 三个通道上的均值࿰c; 后面(0.5, 0.5, 0.5)是三个通道的标准差࿰c;注意通道顺序是 R G B ࿰c;用过opencv的同学应该知道openCV读出来的图像是 BRG顺序。这两个tuple数据是用来对RGB 图像做归一化的࿰c;如其名称 Normalize 所示这里都取0.5只是一个近似的操作࿰c;实际上其均值和方差并不是这么多࿰c;但是就这个示例而言 影响可不计。精确值是通过分别计算R,G,B三个通道的数据算出来的。

transform = transforms.Compose([
#     transforms.CenterCrop(224),
    transforms.RandomCrop(32,padding=4), # 数据增广
    transforms.RandomHorizontalFlip(),  # 数据增广
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
]) 

Trainloader其实是一个比较重要的东西࿰c;我们后面就是通过Trainloader把数据传入网络࿰c;当然这里的Trainloader其实是个变量名࿰c;可以随便取࿰c;重点是他是由后面的torch.utils.data.DataLoader()定义的࿰c;这个东西来源于torch.utils.data模块

Batch_Size = 256
Trainset = datasets.CIFAR10(root='./data', Train=True,download=True, transform=transform)
testset = datasets.CIFAR10(root='./data',Train=false,download=True,transform=transform)
Trainloader = torch.utils.data.DataLoader(Trainset, batch_size=Batch_Size,shuffle=True, num_workers=2)
testloader = torch.utils.data.DataLoader(testset, batch_size=Batch_Size,shuffle=True, num_workers=2)
classes = ('plane', 'car', 'bird', 'cat','deer', 'dog', 'frog', 'horse', 'ship', 'truck')
Files already downloaded and verified
Files already downloaded and verified

2. 查看数据(格式࿰c;大小࿰c;形状)

首先可以查看类别

classes = Trainset.classes
classes
['airplane',
 'automobile',
 'bird',
 'cat',
 'deer',
 'dog',
 'frog',
 'horse',
 'ship',
 'truck']
Trainset.class_to_idx
{'airplane': 0,
 'automobile': 1,
 'bird': 2,
 'cat': 3,
 'deer': 4,
 'dog': 5,
 'frog': 6,
 'horse': 7,
 'ship': 8,
 'truck': 9}

也可以查看一下训练集的数据

Trainset.data.shape #50000是图片数量࿰c;32x32是图片大小࿰c;3是通道数量RGB
(50000, 32, 32, 3)

查看数据类型

#查看数据类型
print(type(Trainset.data))
print(type(Trainset))
<class 'numpy.ndarray'>
<class 'torchvision.datasets.cifar.CIFAR10'>

总结

Trainset.data.shape是标准的numpy.ndarray类型࿰c;其中50000是图片数量࿰c;32x32是图片大小࿰c;3是通道数量RGB; Trainset是标准的??类型࿰c;其中50000为图片数量࿰c;0表示取前面的数据࿰c;2表示3通道数RGB࿰c;32*32表示图片大小

3. 查看图片

import numpy as np
import matplotlib.pyplot as plt
plt.imshow(Trainset.data[0])
im,label = iter(Trainloader).next()

Pytorch CIFAR10图像分类 数据加载与可视化篇

np.ndarray转为torch.Tensor

在深度学习中࿰c;原始图像需要转换为深度学习框架自定义的数据格式࿰c;在pytorch中࿰c;需要转为torch.Tensor。 pytorch提供了torch.Tensornumpy.ndarray转换为接口:

方法名作用
torch.from_numpy(xxX)numpy.ndarray转为torch.Tensor
tensor1.numpy()获取tensor1对象的numpy格式数据

torch.Tensor 高维矩阵的表示: N x C x H x W

numpy.ndarray 高维矩阵的表示:N x H x W x C

因此在两者转换的时候需要使用numpy.transpose( ) 方法 。

def imshow(img):
    img = img / 2 + 0.5
    img = np.transpose(img.numpy(),(1,2,0))
    plt.imshow(img)
imshow(im[0])

Pytorch CIFAR10图像分类 数据加载与可视化篇

plt.figure(figsize=(8,12))
imshow(torchvision.utils.@H_240_105@make_grid(im[:32]))

Pytorch CIFAR10图像分类 数据加载与可视化篇

大佬总结

以上是大佬教程为你收集整理的Pytorch CIFAR10图像分类 数据加载与可视化篇全部内容,希望文章能够帮你解决Pytorch CIFAR10图像分类 数据加载与可视化篇所遇到的程序开发问题。

如果觉得大佬教程网站内容还不错,欢迎将大佬教程推荐给程序员好友。

本图文内容来源于网友网络收集整理提供,作为学习参考使用,版权属于原作者。
如您有任何意见或建议可联系处理。小编QQ:384754419,请注明来意。