您的位置:首页 > 健康 > 养生 > 免费查企业信息查询_微信营销软件升级版_微商软文推广平台_买友情链接有用吗

免费查企业信息查询_微信营销软件升级版_微商软文推广平台_买友情链接有用吗

2024/12/23 12:57:55 来源:https://blog.csdn.net/BH04250909/article/details/143368390  浏览:    关键词:免费查企业信息查询_微信营销软件升级版_微商软文推广平台_买友情链接有用吗
免费查企业信息查询_微信营销软件升级版_微商软文推广平台_买友情链接有用吗

1. 名词解释

FFN

  • FFN : Feedforward Neural Network,前馈神经网络
  • 馈神经网络是一种基本的神经网络架构,也称为多层感知器(Multilayer Perceptron,MLP)
  • FFN 一般主要是包括多个全连接层(FC)的网络,其中,全连接层间可以包含 : 激活层、BN层、Dropout 层。

MLP 与 FFN 的区别

在机器学习和深度学习中,MLP(多层感知机)和 FFN(前馈神经网络)在很大程度上可以视为同义词,都指代了一个具有多个层的前馈神经网络结构。

  • MLP(多层感知机)更偏向于表达网络结构(多个全连接层)
  • FFN(前馈神经网络)更偏向于表达数据以前馈的方式流动

MLP 和 FFN 通常指的是只包含全连接层 和激活函数的神经网络结构。这两者都是基本的前馈神经网络类型,没有包含卷积层或其他复杂的结构。

Logit

“Logit” 通常指的是神经网络中最后一个隐藏层的输出,经过激活函数之前的值。比如:

  • 对于二分类问题,logit 是指网络输出的未经过 sigmoid 函数处理的值
  • 对于多分类问题,logit 是指网络输出的未经过 softmax 函数处理的值

NLL

NLL 是 Negative Log-Likelihood(负对数似然)的缩写。
在深度学习中,特别是在分类问题中,NLL 经常与交叉熵损失(Cross-Entropy Loss)等价使用。

Anchor Box 与 Anchor Point

  • Anchor box 通常表示 一个包含位置和大小信息的四元组 ( x , y , w , h ) (x, y, w, h) (x,y,w,h),而 Anchor point 通常表示 一个二元组 ( x , y ) (x, y) (x,y)。 其中, x x x y y y表示框的中心坐标, w w w h h h表示框的宽度和高度。
  • Anchor box 是目标检测中用于定义目标位置和大小的一种方式。而 Anchor point 主要用于在图像上生成 anchor box 的位置,生成的 anchor box 会在 anchor point 的周围不同尺寸和宽高比的情况下进行缩放,形成一系列不同形状的框。

parameter efficient

参数效率高,指的是网络在达到良好性能的同时所使用的参数数量较少。

Deep Supervision

Deep Supervision 是一种训练策略,旨在提高网络的梯度流动,并促使网络更快地收敛,并且有助于缓解梯度消失问题。Deep Supervision 的核心思想是在网络的不同层中引入额外的监督信号,而不仅仅在最后一层输出进行监督训练。具体来说:Deep Supervision 会使用网络的中间层输出,计算出一部分损失函数,然后和网络最后一层的损失函数一起,对网络的参数进行优化。

DP 与 DDP

DP : DataParallel,数据并行
DDP :Distributed Data Parallel,分布式数据并行

感受野(Receptive Field)

1、介绍

感受野(receptive field)是卷积神经网络输出特征图上的像素点在原始图像上所能看到的(映射的)区域的大小,它决定了该像素对输入图像的感知范围(获取信息的范围)。较小的感受野可以捕捉到更细节的特征,而较大的感受野可以捕捉到更全局的特征。
在这里插入图片描述
如果连续进行 2次卷积操作,卷积核大小都为 3x3,stride=1, padding=0, 如下图,layer3上的每一个像素点在 layer1上的感受野 为 5x5
在这里插入图片描述

2、感受野计算公式

感受野计公式 : F ( i ) = ( F ( i + 1 ) − 1 ) × S t r i d e + K s i z e F(i)=(F(i+1)-1)\times Stride + Ksize F(i)=F(i+1)1×Stride+Ksize F i n = ( F o u t − 1 ) × S t r i d e + K s i z e F_{in}=(F_{out}-1)\times Stride + Ksize Fin=Fout1×Stride+Ksize
其中:

  • F ( i ) F(i) F(i) :在第 i i i层的感受野
  • S t r i d e Stride Stride:第 i i i层步距
  • K s i z e Ksize Ksize:第 i i i层卷积或池化的 kernel size

3、计算举例

求 :layer3 上的每个像素在 layer1 上的感受野。
在这里插入图片描述
1)先来计算 layer3 上的一个像素( F ( 3 ) = 1 F(3)=1 F(3)=1)在 layer2 上的感受野 :
F ( 2 ) = ( F ( 3 ) − 1 ) × S t r i d e + K s i z e = ( 1 − 1 ) × 2 + 2 = 2 F(2) = (F(3)-1) \times Stride + Ksize = (1 -1) \times 2 + 2 = 2 F(2)=(F(3)1)×Stride+Ksize=(11)×2+2=2

2)计算 layer3 上的一个像素( F ( 3 ) = 1 , F ( 2 ) = 2 F(3)=1, \; F(2)=2 F(3)=1F(2)=2 )在 layer1 上的感受野 :
F ( 1 ) = ( F ( 2 ) − 1 ) × S t r i d e + K s i z e = ( 2 − 1 ) × 2 + 3 = 5 F(1)=(F(2)-1)\times Stride + Ksize =(2 -1)\times 2 + 3 = 5 F(1)=(F(2)1)×Stride+Ksize=(21)×2+3=5

如果仅计算 layer2 上的一个像素( F(2)=1 )在 layer1 上的感受野 :
F ( 1 ) = ( F ( 2 ) − 1 ) × S t r i d e + K s i z e = ( 1 − 1 ) × 2 + 3 = 3 F(1)=(F(2)-1)\times Stride + Ksize = (1 -1)\times 2 + 3 = 3 F(1)=F(2)1×Stride+Ksize=11×2+3=3

2. tensor 相关

tensor 内部存储结构

1、数据区域和元数据

PyTorch 中的 tensor 内部结构通常包含了 数据区域(Storage) 和 元数据(Metadata) :

  • 数据区域 : 存储了 tensor 的实际数据,且数据被保存为连续的数组。比如: a = torch.tensor([[1, 2, 3], [4, 5, 6]]),它的数据在存储区的保存形式为 [1, 2, 3, 4, 5, 6]
  • 元数据 :包含了 tensor 的一些描述性信息,比如 : 尺寸(Size)、步长(Stride)、数据类型(Data Type) 等信息

占用内存的主要是 数据区域,且取决于 tensor 中元素的个数, 而元数据占用内存较少。
采用这种 【数据区域 + 元数据】 的数据存储方式,主要是因为深度学习的数据动辄成千上万,数据量巨大,所以采取这样的存储方式以节省内存
在这里插入图片描述


2、查看 tensor 的存储区数据: storage()

虽然 .storage() 方法即将被弃用,而是改用 .untyped_storage(),但为了笔记中展示方便,我们仍然使用 .storage() 方法。.untyped_storage() 方法的输出太长了,不方便截图放在笔记中。

a = torch.tensor([[1, 2, 3],[4, 5, 6]])print(a.storage())

在这里插入图片描述


3、查看 tensor 的步长: stride()

stride() : 在指定维度 (dim) 上,存储区中的数据元素,从一个元素跳到下一个元素所必须的步长

a = torch.randn(3, 2)
print(a.stride())  # (2, 1)

解读:
在这里插入图片描述
在第 0 维,想要从一个元素跳到下一个元素,比如从 a[0][0] 到 a[1][0] ,需要经过 2个元素,步长是 2
在第 1 维,想要从一个元素跳到下一个元素,比如从 a[0][0] 到 a[0][1], 需要经过 1个元素,步长是 1

4、查看 tensor 的偏移量:storage_offset()

表示 tensor 的第 0 个元素与真实存储区的第 0 个元素的偏移量

a = torch.tensor([1, 2, 3, 4, 5])
b = a[1:]   # tensor([2, 3, 4, 5])
c = a[3:]   # tensor([4, 5])
print(b.storage_offset())   # 1
print(c.storage_offset())   # 3
  • b 的第 0 个元素与 a 的第 0 个元素之间的偏移量是 1
  • c 的第 0 个元素与 a 的第 0 个元素之间的偏移量是 3

5、代码举例

  • 一般来说,一个 tensor 有着与之对应的 storage, storage 是在 data 之上封装的接口。

  • 不同 tensor 的元数据一般不同,但却可能使用相同的 storage。

  • data_ptr()

    • 返回的是张量数据 (storage 数据)存储的实际内存地址,确切来说是张量数据的起始内存地址。
    • data_ptr 中的 ptr 是 pointer(指针)的缩写,对应于 C 语言中的指针,因为 Python 的底层就是由 C 实现的
  • id(a)

    • 返回的是 a 在 Python 内存管理系统中的唯一标识符。虽然这个标识符通常与对象的内存地址有关,但它并不直接表示内存地址。

1)观察一

import torcha = torch.arange(0, 6)
print('a = {}\n'.format(a))
print('tensor a 存储区的数据内容 :{}\n'.format(a.storage()))
print('tensor a 相对于存储区数据的偏移量 :{}\n'.format(a.storage_offset()))print('*'*20, '\n')b = a.view(2,3)
print('b = {}\n'.format(b))
print('tensor b 存储区的数据内容 :{}\n'.format(b.storage()))
print('tensor b 相对于存储区数据的偏移量 :{}\n'.format(b.storage_offset()))

在这里插入图片描述
2)观察二

import torcha = torch.tensor([1, 2, 3, 4, 5, 6])
b = a.view(2, 3)print(a.data_ptr())   # 140623757700864
print(b.data_ptr())   # 140623757700864print(id(a))   # 4523755392
print(id(b))   # 4602540464

在这里插入图片描述

  • a.data_ptr()b.data_ptr() 一样,说明 tensor a 和 tensor b 共享相同的存储区,即,它们指向相同的底层数据存储对象。
  • id(a)id(b) 不一样,是因为虽然 a 和b 共享storage 数据,但是 它们 有不同的 size 或者 strides 、 storage_offset 等其他属性

3)观察三

import torcha = torch.tensor([1, 2, 3, 4, 5, 6])
c = a[2:]print(c.storage())print('\n', '*'*20, '\n')print('tensor a 首元素的内存地址 : {}'.format(a.data_ptr()))
print('tensor c 首元素的内存地址 : {}'.format(c.data_ptr()))
print(c.data_ptr() - a.data_ptr())print('\n', '*'*20, '\n')c[0] = -100
print(a)

在这里插入图片描述

  • data_ptr() 返回 tensor 首元素的内存地址
  • c 和 a 的首元素内存地址相差 16,每个元素占用 8 个字节(LongStorage), 也就是首元素相差两个元素
  • 改变 c 的首元素, a 对应位置的元素值也被改变

6、总结

  1. 由上可知,绝大多数操作并不修改 tensor 的数据,只是修改了 tensor 的元数据,比如修改 tensor 的 offset 、stride 和 size ,这种做法更节省内存,同时提升了处理速度。
  2. 有些操作会导致 tensor 不连续,这时需要调用 torch.contiguous 方法将其变成连续的数据,该方法会复制数据到新的内存,不再与原来的数据共享 storage。

3.Dataset 与 DataLoader

  • Dataset 作用 :
    • 定义和管理如何获取单个数据样本及其标签
    • 包含 加载数据/读取数据、预处理数据、图像增强 等一系列操作
    • 返回 单个数据样本 的处理结果
  • DataLoader 作用 :
    • 指定数据读取规则,一般通过 sampler 指定
    • 指定 batch数据的打包规则,通过 collate_fn 指定
    • 每次迭代,返回的是 一个batch 的数据

在这里插入图片描述

生成 Dataset 方式一 :自定义Dataset

所谓的 自定义 dataset ,即:我们自己去写一个 Dataset 类 :

  • 一般需要继承 torch.utils.data.Dataset
    • 继承 torch.utils.data.Dataset 主要是为了与 DataLoader 保持兼容,确保数据集遵循 DataLoader 的接口标准,方便后续使用 PyTorch 提供的工具,比如 :批量加载、打乱数据、并行处理等功能
  • 并且满足和 DataLoader 进行交互的规范 :
    • DataLoader 会调用 Dataset 的 len() 和 getitem() 方法,所以自定义 Dataset 类必须实现这两个方法,如此才能保证 DataLoader 可以正确地加载和操作你的数据集

1、自定义 Dataset 的三个重要方法

创建自定义 Dataset 时,必须实现的3个方法 :init()、len()、 getitem()。
这些方法定义了数据集的基本结构和行为,也是 DataLoader 可以正确的从 Dataset 中读取数据的基础。
1)init 方法

  • 参数: 根据需要传递一些参数,例如文件路径、数据转换等。
  • 作用: 可以在这里进行一些初始化工作,例如:设置文件路径、定义数据转换transforms 等。
def __init__(self, data_folder, train, transform=None):self.data_folder = data_folderself.transform = transformself.file_list = os.listdir(data_folder)self.train = train

2)len 方法

def __len__(self):return len(self.file_list)
  • 返回值: 需返回数据集中的样本的总数。
  • 作用:
    • 方便通过调用 len(dataset) 来获取数据量,其中 dataset 为 Dataset 对象
    • Dataloader 会用它 和 batch_size 一起来计算一个epoch 要迭代多少个 steps: s t e p s = l e n ( d a t a s e t ) b a t c h s i z e steps = \frac{len(dataset)}{batchsize} steps=batchsizelen(dataset)

3)getitem 方法

def __getitem__(self, idx):img_name = os.path.join(self.data_folder, self.file_list[idx])original_image = Image.open(img_name)label = img_name.split('_')[-1].split('.')[0]if self.train:image = self.transform(original_image)else:image = self.transform(original_image)return image, label
  • 参数: index 是样本的索引。
  • 返回值: 返回数据集中索引指定的样本。通常是一个包含输入数据和对应标签的元组。这里可以根据自己的需求,进行自定义。
  • 作用: 根据给定的索引返回数据集中的一个样本。这是用于获取数据集中单个样本的方法。
    比如,可以通过 dataset[0] 来获取 dataset 中的索引为 0 的样本

以上这三个方法一起定义了 PyTorch 中的 dataset 类,并支持使用 torch.utils.data.DataLoader 来加载数据并进行训练。


2、使用举例

用 CIRFAR-100 数据集生成 Dataset
在这里插入图片描述

from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import osclass CustomDataset(Dataset):def __init__(self, data_folder, train, transform=None):self.data_folder = data_folderself.transform = transformself.file_list = os.listdir(data_folder)self.train = traindef __getitem__(self, idx):img_name = os.path.join(self.data_folder, self.file_list[idx])original_image = Image.open(img_name)label = img_name.split('_')[-1].split('.')[0]if self.train:image = self.transform(original_image)else:image = self.transform(original_image)return image, labeldef __len__(self):return len(self.file_list)images_dir = "/Users/enzo/Documents/GitHub/dataset/CIFAR/cifar-100-images/train"
dataset = CustomDataset(images_dir, train=True, transform=transforms.ToTensor())print(len(dataset))# 通过 dataset 对象,获取索引为 0 的样本
sample_image, sample_label = dataset[0]
print("Sample Image Shape:", sample_image.shape)
print("Sample Label:", sample_label)

输出:
在这里插入图片描述

生成 Dataset 方式二 :torchvision.datasets 模块

1、pytorch 官方支持下载的数据集

官网地址 : 点击查看
在这里插入图片描述
注 :

  • 对于一部分数据集,提供下载功能
  • 对于一部分数据集,不提供下载功能 (具体情况取决于数据集的来源和许可协议)

2、torchvision.datasets 模块

以获取 MNIST 数据集为例 (pytorch官方文档地址 : 点击查看)
MNIST 全称:mixed national institute of standards and technology database

train_dataset = torchvision.datasets.MNIST(root,    train=True,               transform=None,  target_transform= None  download=True)

参数 :

  • root :数据集存放的路径
  • train:如果是True, 下载训练集 trainin.pt; 如果是False,下载测试集 test.pt。 默认是True
  • transform:一系列作用在PIL图片上的转换操作
  • download:是否下载数据集,默认为 False
    • 若设置 download=True
      • root 目录下没有该数据集,数据集将会被下载到 root 指定的位置。
      • root 目录下已经存在该数据集,则不会重新下载,而是会直接使用已存在的数据,以节省时间
    • 若设置 download=False,程序将会在 root 指定的位置查找数据集,如果数据集不存在,则会抛出错误。

3、举例 1:torchvision.datasets.MNIST

  • 因为是单通道,所以 transforms.Normalize 的均值和标准差 仅指定了一个值
  • 记得把数据集的下载地址换掉,换成你想要它下载到的位置
import torchvision
from torchvision.transforms import transforms
import torch.utils.data as data
import matplotlib.pyplot as pltbatch_size = 5my_transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize(mean=[0.5],  # mean=[0.485, 0.456, 0.406]std=[0.5])])  # std=[0.229, 0.224, 0.225]train_dataset = torchvision.datasets.MNIST(root="./",train=True,transform=my_transform,download=True)val_dataset = torchvision.datasets.MNIST(root="./",train=False,transform=my_transform,download=True)train_loader = data.DataLoader(train_dataset,batch_size=batch_size,shuffle=True)val_loader = data.DataLoader(val_dataset,batch_size=batch_size,shuffle=True)print(len(train_dataset))
print(len(val_dataset))image, label = next(iter(train_loader))
print(image.shape)
print(label)for i in range(batch_size):plt.subplot(1, batch_size, i + 1)plt.title(label[i].item())plt.axis("off")plt.imshow(image[i].permute(1, 2, 0))plt.show()

输出:
在这里插入图片描述


4、举例 2:torchvision.datasets.CocoDetection

官方文档 : 点击查看

torchvision.datasets.CocoDetection 不支持 COCO 数据集下载
在使用 torchvision.datasets.CocoDetection 之前,需要确保已经下载并淮备好COCO数
据集的图像和标注文件。然后使用 torchvision.datasets.CocoDetection 类来加载 COCO数据集。

torchvision.datasets.CocoDetection(root, annFile, transform=None, target_transform=None, transforms=None)

参数 :

  • root : 指定图片地址 (本地已经下载下来的图像地址)
  • annFile : 指定标注文件地址( 本地已经下载下来的标注文件地址)
  • transform : 图像处理 (用于PIL)
  • target_transform : 标注处理
  • transforms : 图像和标注的处理

使用举例:

  • 记得把数据集的下载地址换掉,换成你的 COCO数据集地址
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle
import torch
import torchvision
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torchvision.transforms import functional as F
import randomdef collate_fn_coco(batch):return tuple(zip(*batch))coco_det = datasets.CocoDetection(root="./COCO2017/train2017",annFile="./COCO2017/annotations/instances_train 2017.json")sampler = torch.utils.data.SequentialSampler(coco_det)  # RandomSampler
batch_sampler = torch.utils.data.BatchSampler(sampler, 1, drop_last=True)
data_loader = torch.utils.data.DataLoader(coco_det,batch_sampler=batch_sampler,collate_fn=collate_fn_coco)# 可视化
iterator = iter(data_loader)
imgs, gts = next(iterator)
img,  gts_one_img = imgs[0], gts[0]bboxes = []
ids = []
for gt in gts_one_img:bboxes.append([gt['bbox'][0],gt['bbox'][1],gt['bbox'][2],gt['bbox'][3]])ids.append(gt['category_id'])fig, ax = plt.subplots()
for box, id in zip(bboxes, ids):x = int(box[0])y = int(box[1])w = int(box[2])h = int(box[3])rect = plt.Rectangle((x, y), w, h, edgecolor='r', linewidth=2, facecolor='none')ax.add_patch(rect)ax.text(x, y, id, backgroundcolor="r")plt.axis("off")
plt.imshow(img)
plt.show()

输出效果:
在这里插入图片描述

DataLoader

1、torch.utils.data.DataLoader

官方文档 :点击查看

from torch.utils.data import DataLoaderdata_loader = DataLoader(dataset, batch_size=1, shuffle=None, sampler=None, batch_sampler=None, num_workers=0, collate_fn=None, pin_memory=False, drop_last=False,timeout=0)

参数:

  • dataset : 加载数据的数据集
  • batch size : 每批返回的数据量,默认值是 1
  • shuffle:是否在每个 epoch 内将数据打乱顺序。默认值为False
  • sampler :从数据集中提取的样本序列。可以用来自定义样本的采样策路。默认值为None
  • batch_sampler :与sampler类似,但是一次返回一个 batch的索引,用于自定义 batch。它与 与 batch size、shuffle、sampler 和 drop last 互斥
  • num workers : 用于数据加载的子进程数。0表示主进程加载。默认值为0
  • collate_fn: 用于指定如何组合样本数据。如果为None,那么将默认使用默认的组合方法
  • drop_last : 如果数据集的大小不能被 batch _size 整除,那么是否丢弃最后一个数据批次。默认值为 False
  • pin_memory : 将数据固定在内存的锁页内存中,加速数据读取的速度。默认值为False.
  • timeout : workers :等待 collect 一个 batch 的数据的超时时间。默认为 0,表示一直等待

2、常用参数图示

dataset 对 Dataloader 有 2个作用 :

  • 通过 dataset 的 length 方法,dataloader 可以知道数据量,从而根据数据量生成相应的索引列表
  • dataloader 会将索引,传给 dataset 的 getitem 方法,通过 getitem 方法对数据进行处理,并返回处理好的数据

在这里插入图片描述


3、Dataset 与 Dataloader 的内部交互细节 举例

在这里插入图片描述

num_workers 与 pin_memory

1、参数 num_workers

参数 num_workers 参数用于指定 加载数据的子进程的数量

  • num_workers=0 :(默认值) 表示只有主进程去加载 batch数据,这个可能会是一个瓶颈。
  • num_workers=1 :表示只有一个子进程加载数据,主进程不参与,这仍可能导致速度慢。
  • num_workers>0 :表示指定数量的子进程并行加载数据,且主进程不参与。

增加num_workers可以提高加载速度,但也会增加 CPU 和 内存的使用。
通常建议将 num_workers 参数设置为等于或小于 CPU 核心数,以有效平衡数据加载效率和系统资源占用率。
进程之间是动态调度的,谁先做完一个样本:

batch_size = 16
nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8])   # number of workers
train_dataloader = torch.utils.data.DataLoader(train_dataset,batch_size=batch_size,num_workers=nw,shuffle=True,pin_memory=True,collate_fn=collate_fn)

2、数据加载过程是如何并行的

一个进程处理一个 batch 的数据,假设设置 num_workers=2 ,则 进程1 处理一个batch 的数据,进程2 处理另一个batch 的数据

在这里插入图片描述
并行工作流程 :

  • 初始化:创建 DataLoader 实例时,通过参数 num_workers 指定并行加载的子进程数量
  • 子进程加载数据:子进程独立于主进程运行,每个子进程的拿着一个batch 的索引,并行的到 dataset 的 getitem 中预处理数据
  • 数据准备:处理好的数据,放入缓冲区以备主进程请求
  • 数据请求:主进程在 for 循环中请求下一个 batch
  • 数据传输:主进程请求数据时,从缓冲区获取已经准备好的 batch
  • 循环迭代:主进程不断请求数据,子进程并行的处理后续的 batch 数据

3、pin_memory

  • 若设置 pin_memory=True,数据会被加载到CPU的锁定内存中,从而提高数据从 CPU 到 GPU 的传输效率

这是因为锁定的内存(pinned memory)可以更快地被复制到GPU,因为它是连续的,并且已经准备好被传输。

  • 若设置 pin_memory=False ,则数据是被存放在分页内存(pageable memory)中,当我们想要把数据从 cpu 移动到 gpu 上 (执行 .to('cuda') 的时候), 需要先将数据从分页内存中 移动到锁页内存中,然后再传输到 GPU 上

所以,设置 pin_memory=True ,节省的是 将数据从 分页内存移动到锁页内存中 的这段时间

如果你的训练完全在CPU上进行,不涉及GPU,那就没有必要设置 pin_memory=True。因为在这种情况下,数据不需要被传输到GPU,因此不需要使用锁定内存来加速这一过程。可以将 pin_memory 设置为 False,以简化内存管理。
在这里插入图片描述

sampler 与 batch_sampler

1、sampler

torch.utils.data.DataLoader 的参数 sampler 接收的通常是一个实现了 Sampler 接口的对象,比如 :

sampler = SequentialSampler(dataset)   # 使用 SequentialSampler
dataloader = DataLoader(dataset, batch_size=8, sampler=sampler)

通过 sampler 对象来控制数据集的索引顺序,从而影响数据从数据集中的抽取方式
1)pytorch 提供的,可以直接使用的几种 sampler

# 顺序抽样,按照数据集的顺序逐个抽取样本
torch.utils.data.sampler.SequentialSampler()# 随机抽样,数据集中的样本以随机顺序被抽取
torch.utils.data.sampler.RandomSampler()# 从指定的样本索引子集内进行随机抽样
torch.utils.data.sampler.SubsetRandomSampler()# 根据样本的权重随机抽样,不同样本有不同的抽样概率
torch.utils.data.sampler.WeightedRandomSampler()

2)可以自定义 sampler,比如以下是 yolov5 中自定义的 sampler :
在这里插入图片描述
参数 sampler 有一部分功能,是和 参数 shuffle 是重叠的:

  • SequentialSampler 效果等价于 shuffle=False
  • RandomSampler 效果等价于 shuffle=Ture
    Pytorch 提供 sampler 参数,主要是为提升灵活性,支持用户更灵活地设计数据加载的方式

下面我们主要介绍 SequentialSampler 和 RandomSampler, 只要大家通过 SequentialSampler 、RandomSampler 掌握了 sampler 的工作原理,便可以愉快的自定义的去设计 sampler 了。


1)顺序采样 SequentialSampler

作用 :接收一个 Dataset 对象,输出数据包中样本量的顺序索引

举例 1

import torch.utils.data.sampler as samplerdata = list([17, 22, 3, 41, 8])
seq_sampler = sampler.SequentialSampler(data_source=data)for index in seq_sampler:print("index: {}".format(index))

在这里插入图片描述
相关源码

class SequentialSampler(Sampler):data_source: Sizeddef __init__(self, data_source: Sized) -> None:self.data_source = data_sourcedef __iter__(self) -> Iterator[int]:return iter(range(len(self.data_source)))def __len__(self) -> int:return len(self.data_source)
  • init 接收参数:Dataset 对象
  • iter 返回一个可迭代对象(返回的是索引值),因为 SequentialSampler 是顺序采样,所以返回的索引是顺序数值序列
  • len 返回 dataset 中数据个数

举例 2

from torch.utils.data import DataLoader, Dataset
from torch.utils.data.sampler import SequentialSamplerclass myDataset(Dataset):def __init__(self, data):self.data = datadef __len__(self):return len(self.data)def __getitem__(self, idx):return self.data[idx]# 示例数据 :0 到 19 的整数
data = [i for i in range(20)]
dataset = myDataset(data)# 使用 SequentialSampler
sampler = SequentialSampler(dataset)# 创建 DataLoader
dataloader = DataLoader(dataset, batch_size=8, sampler=sampler)# 使用 DataLoader 迭代数据
for data in dataloader:print(data)

在这里插入图片描述


2)随机采样 RandomSampler

作用 :接收一个 Dataset 对象,输出数据包中样本量的随机索引 (可指定是否可重复)。

举例 1

import torch.utils.data.sampler as samplerdata = list([17, 22, 3, 41, 8])
seq_sampler = sampler.RandomSampler(data_source=data)for index in seq_sampler:print("index: {}".format(index))

在这里插入图片描述
相关源码 (删减版本)

class RandomSampler(Sampler):def __init__(self, data_source, replacement=False, num_samples=None):self.data_source = data_sourceself.replacement = replacementself._num_samples = num_samplesdef num_samples(self):if self._num_samples is None:return len(self.data_source)return self._num_samplesdef __len__(self):return self.num_samplesdef __iter__(self):n = len(self.data_source)if self.replacement:# 生成的随机数是可能重复的return iter(torch.randint(high=n, size=(self.num_samples,), dtype=torch.int64).tolist())# 生成的随机数是不重复的return iter(torch.randperm(n).tolist())

查看 torch.randperm() 的使用 :

  • init 参数 :
    • data_source (Dataset): 采样的 Dataset 对象
    • replacement (bool): 如果为 True,则抽取的样本是有放回的。默认为 False
    • num_samples (int): 抽取样本的数量,默认是len(dataset)。当 replacement 是 True 时,应被实例化
  • iter 返回一个可迭代对象(返回的是索引),因为 RandomSampler 是随机采样,所以返回的索引是随机的数值序列 (当 replacement=False 时,生成的排列是无重复的)
  • len 返回 dataset 中样本量

举例 2

from torch.utils.data import DataLoader, Dataset
from torch.utils.data.sampler import RandomSamplerclass myDataset(Dataset):def __init__(self, data):self.data = datadef __len__(self):return len(self.data)def __getitem__(self, idx):return self.data[idx]# 示例数据 :0 到 19 的整数
data = [i for i in range(20)]
dataset = myDataset(data)# 使用 SequentialSampler
sampler = RandomSampler(dataset)# 创建 DataLoader
dataloader = DataLoader(dataset, batch_size=8, sampler=sampler)# 使用 DataLoader 迭代数据
for data in dataloader:print(data)

在这里插入图片描述


2、sampler 与 shuffle 的互斥

  • 参数 sampler 与 参数 shuffle 是互斥的,不要同时使用 sampler 和 shuffle
  • 因为 shuffle 的默认值为 False,所以代码会兼容 shuffle 等于默认值 False 的情况,即 :
    • 当同时设置了 shuffle 与 sampler,且 shuffle=True,会报错
    • 当同时设置了 shuffle 与 sampler,且 shuffle=False,具体逻辑按照 sampler

3、批采样 BatchSampler

官方文档 :

https://pytorch.org/docs/stable/data.html#torch.utils.data.BatchSampler

torch.utils.data.DataLoaderde 的参数 batch_sample, 接收的一般是 torch.utils.data.BatchSampler 对象,
torch.utils.data.BatchSampler 的作用 : 包装另一个采样器,生成一个小批量索引采样器

torch.utils.data.BatchSampler(sampler, batch_size, drop_last)

举例 1

import torch.utils.data.sampler as sampler
data = list([17, 22, 3, 41, 8])seq_sampler = sampler.SequentialSampler(data_source=data)
batch_sampler = sampler.BatchSampler(seq_sampler, 2, False )for index in batch_sampler:print(index)

在这里插入图片描述
相关源码 (删减版本)

class BatchSampler(Sampler):def __init__(self, sampler, batch_size, drop_last):、self.sampler = samplerself.batch_size = batch_sizeself.drop_last = drop_lastdef __iter__(self):batch = []for idx in self.sampler:batch.append(idx)# 如果采样个数和batch_size相等则本次采样完成if len(batch) == self.batch_size:yield batchbatch = []# for 结束后在不需要剔除不足batch_size的采样个数时返回当前batch        if len(batch) > 0 and not self.drop_last:yield batchdef __len__(self):# 在不进行剔除时,数据的长度就是采样器索引的长度if self.drop_last:return len(self.sampler) // self.batch_sizeelse:return (len(self.sampler) + self.batch_size - 1) // self.batch_size
  • 参数 :
    • sampler : 其他采样器实例
    • batch_size :批量大小
    • drop_last :为 “True”时,如果最后一个batch 采样得到的数据个数小于batch_size,则抛弃最后一个batch的数据

举例 2

import torch
from torch.utils.data import DataLoader, Dataset
from torch.utils.data.sampler import SequentialSampler, BatchSamplerclass myDataset(Dataset):def __init__(self, data):self.data = datadef __len__(self):return len(self.data)def __getitem__(self, idx):return self.data[idx]# 示例数据 :# 生成 0 到 19 的整数
data = [i for i in range(20)]
dataset = myDataset(data)# 使用 SequentialSampler 顺序采样
sequential_sampler = SequentialSampler(dataset)# 使用 BatchSampler 将 SequentialSampler 和 batch_size 结合
batch_sampler = BatchSampler(sequential_sampler, batch_size=8, drop_last=False)# 创建 DataLoader,使用 BatchSampler
dataloader = DataLoader(dataset, batch_sampler=batch_sampler)# 使用 DataLoader 迭代数据
for data in dataloader:print(data)

4、BatchSampler 与 其他参数的互斥

如果你在 DataLoader(dataset, batch_sampler=batch_sampler) 中指定了参数 batch_sampler, 那么就不能再指定参数 batch_size、shuffle、sampler、和 drop_last 了,他们互斥。

因为:

  • 你在生成torch.utils.data.sampler.BatchSampler() 的时候,就已经制定过 batch_size、sampler、和 drop_last 这些参数了,
  • batch_sampler 与 shuffle 作用一致,所以也互斥

比如,如下代码就会报错,因为在 DataLoader 中重复指定了 batch_size

random_sampler = sampler.RandomSampler(data_source=dataset)
batch_sampler = sampler.BatchSampler(random_sampler, batch_size=2, drop_last=False)
dataloader = DataLoader(dataset, batch_size=2, batch_sampler=batch_sampler)

在这里插入图片描述

重写 collate_fn 实例

1、collate_fn 函数作用

在使用 torch.utils.data.dataset 时,参数 collate_fn 接受一个函数,该函数的函数名通常就定义为: collate_fn
collate_fn 函数的作用 :将多个 经过 dataset.getitem() 处理好的 样本数据,组合成一个 batch 的数据。
在这里插入图片描述
相关代码见最后【4、附】部分


2、默认 collate_fn 函数

简易实现版本 :

def default_collate(batch):# 检查样本类型并处理if isinstance(batch[0], torch.Tensor):return torch.stack(batch, dim=0)elif isinstance(batch[0], (list, tuple)):return [default_collate(samples) for samples in zip(*batch)]elif isinstance(batch[0], dict):return {key: default_collate([d[key] for d in batch]) for key in batch[0]}elif isinstance(batch[0], int):return torch.tensor(batch)  # 将 int 转换为 Tensorraise TypeError(f"Unsupported type: {type(batch[0])}")

3、自定义 collate_fn 函数

1)常见场景

举例 :一个 batch 中的 多张图片,经过 dataset.getitem() 方法,得到的图像输出尺寸不一样 (比如,可能因为 图像增强 使用 的 transforms ,设计的 最后一步处理方式是范围内的随机裁剪)

因为 网络要求输入数据的尺寸形式为 (batch_size, channel, high,width), 为了将多张图像数据打包成一个batch 的数据形式,需要将图像加上padding,保证所有图像尺寸一致,进而组成 batch 的数据形式

在这里插入图片描述
collate_fn 函数中需要处理的内容为 :

  • 对比 batch 中,所有图像的宽和高,找到最长的宽度 和 最长的高度
  • 将所有的图像都 padding 到最长的宽度 和 最长的高度
  • 处理的得到 mask 数据,用于标注 : 哪些位置是 有效像素,哪些位置是 padding
  • 将所有数据处理成 batch 的格式,进行返回
2)相关代码实现

相关代码 : 点击跳转

data_loader_train = DataLoader(dataset_train, batch_sampler=batch_sampler_train,collate_fn=utils.collate_fn, num_workers=args.num_workers,pin_memory=True)data_loader_val = DataLoader(dataset_val, args.batch_size, sampler=sampler_val,drop_last=False, collate_fn=utils.collate_fn, num_workers=args.num_workers,pin_memory=True)

相关代码 :点击跳转

def collate_fn(batch):batch = list(zip(*batch))batch[0] = nested_tensor_from_tensor_list(batch[0])return tuple(batch)def _max_by_axis(the_list):# type: (List[List[int]]) -> List[int]maxes = the_list[0]for sublist in the_list[1:]:for index, item in enumerate(sublist):maxes[index] = max(maxes[index], item)return maxesdef nested_tensor_from_tensor_list(tensor_list: List[Tensor]):# TODO make this more generalif tensor_list[0].ndim == 3:# TODO make it support different-sized imagesmax_size = _max_by_axis([list(img.shape) for img in tensor_list])# min_size = tuple(min(s) for s in zip(*[img.shape for img in tensor_list]))batch_shape = [len(tensor_list)] + max_sizeb, c, h, w = batch_shapedtype = tensor_list[0].dtypedevice = tensor_list[0].devicetensor = torch.zeros(batch_shape, dtype=dtype, device=device)mask = torch.ones((b, h, w), dtype=torch.bool, device=device)for img, pad_img, m in zip(tensor_list, tensor, mask):pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)m[: img.shape[1], :img.shape[2]] = Falseelse:raise ValueError('not supported')return NestedTensor(tensor, mask)

4、附

在这里插入图片描述
注 :更换 cifar-100 在你本地的路径

import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import ostorch.manual_seed(121)
torch.cuda.manual_seed(121)label_dict = {'apple': 0,'aquarium_fish': 1,'baby': 2,'bear': 3,'beaver': 4,'bed': 5,'bee': 6,'beetle': 7,'bicycle': 8,'bottle': 9,'bowl': 10,'boy': 11,'bridge': 12,'bus': 13,'butterfly': 14,'camel': 15,'can': 16,'castle': 17,'caterpillar': 18,'cattle': 19,'chair': 20,'chimpanzee': 21,'clock': 22,'cloud': 23,'cockroach': 24,'couch': 25,'crab': 26,'crocodile': 27,'cup': 28,'dinosaur': 29,'dolphin': 30,'elephant': 31,'flatfish': 32,'forest': 33,'fox': 34,'girl': 35,'hamster': 36,'house': 37,'kangaroo': 38,'keyboard': 39,'lamp': 40,'lawn_mower': 41,'leopard': 42,'lion': 43,'lizard': 44,'lobster': 45,'man': 46,'maple_tree': 47,'motorcycle': 48,'mountain': 49,'mouse': 50,'mushroom': 51,'oak_tree': 52,'orange': 53,'orchid': 54,'otter': 55,'palm_tree': 56,'pear': 57,'pickup_truck': 58,'pine_tree': 59,'plain': 60,'plate': 61,'poppy': 62,'porcupine': 63,'possum': 64,'rabbit': 65,'raccoon': 66,'ray': 67,'road': 68,'rocket': 69,'rose': 70,'sea': 71,'seal': 72,'shark': 73,'shrew': 74,'skunk': 75,'skyscraper': 76,'snail': 77,'snake': 78,'spider': 79,'squirrel': 80,'streetcar': 81,'sunflower': 82,'sweet_pepper': 83,'table': 84,'tank': 85,'telephone': 86,'television': 87,'tiger': 88,'tractor': 89,'train': 90,'trout': 91,'tulip': 92,'turtle': 93,'wardrobe': 94,'whale': 95,'willow_tree': 96,'wolf': 97,'woman': 98,'worm': 99
}def default_collate(batch):# 检查样本类型并处理if isinstance(batch[0], torch.Tensor):return torch.stack(batch, dim=0)elif isinstance(batch[0], (list, tuple)):return [default_collate(samples) for samples in zip(*batch)]elif isinstance(batch[0], dict):return {key: default_collate([d[key] for d in batch]) for key in batch[0]}elif isinstance(batch[0], int):return torch.tensor(batch)  # 将 int 转换为 Tensorraise TypeError(f"Unsupported type: {type(batch[0])}")class CustomDataset(Dataset):def __init__(self, data_folder, train, transform=None):self.data_folder = data_folderself.transform = transformself.file_list = os.listdir(data_folder)self.train = traindef __getitem__(self, idx):img_name = os.path.join(self.data_folder, self.file_list[idx])original_image = Image.open(img_name)label_name = img_name.split('_')[-1].split('.')[0]label_idx = label_dict[label_name]if self.train:image = self.transform(original_image)else:image = self.transform(original_image)return image, label_idxdef __len__(self):return len(self.file_list)images_dir = "/Users/enzo/Documents/GitHub/dataset/CIFAR/cifar-100-images/train"
dataset = CustomDataset(images_dir, train=True, transform=transforms.ToTensor())data_loader = DataLoader(dataset,batch_size=2,shuffle=True,collate_fn=default_collate)for data in data_loader:image, label = data

RandomSampler 与 shuffle=True 的区别

效果完全没有区别,只是实现方式不一样。

  • shuffle=True 的实现方式: 在每个 epoch 开始时将整个数据集打乱,然后按照打乱后的顺序划分 batch。再按照batch_size 个数依次提取数据
  • sampler.BatchSampler(random_sampler) 的实现方式:(数据不会打乱)
    • step 1、RandomSampler 会生成随机的索引。
    • step 2、BatchSampler 根据上面随机出来的索引生成 batch 组。
    • step 3、拿着每个batch 组的索引去取 数据

相同点:

  1. 每个epoch 都会重新打乱
  2. 都不会重复采样,除非你通过参数指定了可以重复采样

其他说明:

  1. shuffle=True 的性能更高一些,而 BatchSampler灵活性更高,因为你可以通过 BatchSampler 设计更复杂的采样方式
  2. 在 Dataloader 中使用 batch_sampler 的常见目的之一,是为了兼容 DistributedSampler,比如:
if args.distributed:sampler_train = DistributedSampler(dataset_train)sampler_val = DistributedSampler(dataset_val, shuffle=False)
else:sampler_train = torch.utils.data.RandomSampler(dataset_train)sampler_val = torch.utils.data.SequentialSampler(dataset_val)batch_sampler_train = torch.utils.data.BatchSampler(sampler_train, args.batch_size, drop_last=True)data_loader_train = DataLoader(dataset_train,batch_sampler=batch_sampler_train,collate_fn=utils.collate_fn,)
data_loader_val = DataLoader(dataset_val,args.batch_size,sampler=sampler_val,drop_last=False,collate_fn=utils.collate_fn,)

跑个小例子,看一下 :

import torch
import torch.utils.data.sampler as sampler
from torch.utils.data import Dataset, DataLoaderclass MyDataset(Dataset):def __init__(self):self.data = [1, 2, 3, 4, 5]def __len__(self):return len(self.data)def __getitem__(self, index):return self.data[index]dataset = MyDataset()# =============================================
random_sampler = sampler.RandomSampler(data_source=dataset)
batch_sampler = sampler.BatchSampler(random_sampler, batch_size=2, drop_last=False)
dataloader1 = DataLoader(dataset, batch_sampler=batch_sampler)for epoch in range(3):for index, data in enumerate(dataloader1):print(index, data)
print('*'*30)# =============================================
dataloader2 = DataLoader(dataset, batch_size=2, shuffle=True)for epoch in range(3):for index, data in enumerate(dataloader2):print(index, data)

在这里插入图片描述

数据处理&数据增强

数据预处理 和 数据增强,我们一般都是使用 torchvision.transforms 模块来完成的。
我敢说,当你掌握了 torchvision.transforms 的使用方法之后,一定在数据预处理 和 数据增强 方面毫无压力。
官网地址 :

https://pytorch.org/vision/stable/transforms.html#others

在这里插入图片描述


简单使用举例:
1、训练阶段

from torchvision.transforms import transformsmy_transform = transforms.Compose([transforms.RandomResizedCrop(img_size),transforms.RandomHorizontalFlip(),transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])

2、推理阶段

from torchvision.transforms import transformsmy_transform = transforms.Compose([transforms.Resize(original_size*1.143)transforms.CenterCrop(img_size),transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])

注:

  1. 操作顺序 : 几何变换 / 颜色变换 —>> ToTensor() —>> Normalize()
  2. ToTensor()Normalize() 是数据处理的最后2步
  3. 经过 Normalize()处理后,得到的数据,一般可直接输入到模型中使用

1、图像尺寸变换 与 裁剪

1)transforms.Resize

官方文档 : 点击跳转

torchvision.transforms.Resize(size, interpolation=InterpolationMode.BILINEAR, max_size=None)

作用:将图像按照指定的插值方式,resize到指定的尺寸。
参数:

  • size: 输出的图像尺寸。可以是元组 (h, w) ,也可以是单个整数。
    • 如果 size 是元组,则输出大小将分别匹配 h, w 的大小
    • 如果 size 是整数,则图像较小的边将被resize 到此数字,并保持宽高比
  • interpolation: 选用如下插值方法将图像 resize 到输出尺寸
    • PIL.Image.NEAREST 最近邻差值
    • PIL.Image.BILINEAR 双线性差值(默认)
    • PIL.Image.BICUBIC 双三次差值
  • max_size :输出图像的较长边的最大值。仅当 size 为单个整数时才支持此功能。如果图像的较长边在根据 size 缩放后大于 max_size,则 size 将被覆盖,使较长边等于 max_size,这时较短边会小于 size。
from PIL import Image
import torchvision.transforms as transforms
import matplotlib.pyplot as pltoriginal_img = Image.open('image.jpg')  # https://p.ipic.vip/7pvisy.jpg
print(original_img.size)   # (3280, 1818)img_1 = transforms.Resize(1500, max_size=None)(original_img)
print(img_1.size)(2706, 1500)   # (2706, 1500)img_2 = transforms.Resize((1500, 1500))(original_img)
print(img_2.size)   # (1500, 1500)img_3 = transforms.Resize(1500, max_size=1600)(original_img)
print(img_3.size)   # (1600, 886)plt.subplot(141)
plt.axis("off")
plt.imshow(original_img)plt.subplot(142)
plt.axis("off")
plt.imshow(img_1)plt.subplot(143)
plt.axis("off")
plt.imshow(img_2)plt.subplot(144)
plt.axis("off")
plt.imshow(img_3)plt.show()

在这里插入图片描述


2)transforms.CenterCrop

官方文档 : 点击跳转

功能:从图片中心裁剪出尺寸为 size 的图片
参数:

  • size: 所需裁剪的图片尺寸,即输出图像尺寸

注意:

  • 若切正方形,transforms.CenterCrop(100) transforms.CenterCrop((100, 100)),两种写法,效果一样
  • 如果设置的输出的尺寸 大于原图像尺寸,则会在四周补 padding,padding 颜色为黑色(像素值为0)

举例:

from PIL import Image
import torchvision.transforms as transforms
import matplotlib.pyplot as pltoriginal_img = Image.open('image.jpg')  # https://p.ipic.vip/7pvisy.jpg
print(original_img.size)   # (3280, 1818)img_1 = transforms.CenterCrop(1500)(original_img)
img_2 = transforms.CenterCrop((1500, 1500))(original_img)
img_3 = transforms.CenterCrop((3000, 3000))(original_img)plt.subplot(141)
plt.axis("off")
plt.imshow(original_img)plt.subplot(142)
plt.axis("off")
plt.imshow(img_1)plt.subplot(143)
plt.axis("off")
plt.imshow(img_2)plt.subplot(144)
plt.axis("off")
plt.imshow(img_3)plt.show()

在这里插入图片描述


3)transforms.RandomCrop

官方文档 : 点击跳转
功能:

  • 从图片中随机裁剪出尺寸为 size 的图片
  • 如果设置了参数 padding,先添加 padding,再从padding后的图像中随机裁剪出大小为size的图片
    参数:
  • size :所需裁剪的图片尺寸,即输出图像尺寸
  • padding : 设置填充大小
    • padding值形式式为 a 时,上下左右均填充 a 个像素
    • padding值形式式为 (a, b) 时,左右填充 a 个像素,上下填充 b 个像素
    • padding值形式式为 (a, b, c, d) 时,左上右下分别填充 a,b,c,d
  • pad_if_needed :当原图像尺寸小于设置的输出图像尺寸(由参数size指定),是否填充,默认为 False
  • padding_mode :若 pad_if_needed设置为 True,则此参数起作用, 默认值为 “constant”
    • "constant" : 像素值由参数 fill 指定 (默认填充黑色,像素值为0)
    • "edge" : padding 的像素值 为图像边缘像素值
    • "reflect" : 镜像填充,最后一个像素不镜像。([1,2,3,4] --> [3,2,1,2,3,4,3,2])
    • "symmetric" : 镜像填充,最后一个像素也镜像。([1,2,3,4] -->[2,1,1,2,3,4,4,3])
  • fill :指定填充像素值,当 padding_mode 为 constant 时起作用,默认填充黑色,像素值为0

注意:

  • 同时指定参数padding_mode 和 参数fill 时,若 padding_mode 值不为 "constant" ,则 参数fill不起作用。
  • 若指定的输出图像尺寸size 大于输入图像尺寸,并且指定参数 pad_if_needed= False,则会报错类似如下
    在这里插入图片描述

举例:

from PIL import Image
import torchvision.transforms as transforms
import matplotlib.pyplot as pltoriginal_img = Image.open('image.jpg')  # https://p.ipic.vip/7pvisy.jpg
print(original_img.size)   # (3280, 1818)img_1 = transforms.RandomCrop(1500, padding=500)(original_img)
img_2 = transforms.RandomCrop(3000, pad_if_needed=True, fill=(255, 0, 0))(original_img)
img_3 = transforms.RandomCrop(3000, pad_if_needed=True, padding_mode="symmetric")(original_img)plt.subplot(141)
plt.axis("off")
plt.imshow(original_img)plt.subplot(142)
plt.axis("off")
plt.imshow(img_1)plt.subplot(143)
plt.axis("off")
plt.imshow(img_2)plt.subplot(144)
plt.axis("off")
plt.imshow(img_3)plt.show()

在这里插入图片描述

4)transforms.RandomResizedCrop

官方文档 : 点击跳转

torchvision.transforms.RandomResizedCrop(size, scale=(0.08, 1.0), ratio=(0.75, 1.3333333333333333), interpolation=InterpolationMode.BILINEAR)

功能:

  • Step 1 : 将图像进行随机裁剪,裁剪出的图像需满足:
    • 裁剪后的图像面积 占原图像面积的比例 在指定的范围内
    • 裁剪后的图像高宽比 在指定范围内
  • Step 2 :将 Step 1 得到的图像通过指定的方式,进行缩放
    参数:
    • size: 输出的图像尺寸
    • scale: 随机缩放面积比例,默认随机选取 (0.08, 1) 之间的一个数
    • ratio: 随机长宽比,默认随机选取 (0.75, 1.33333 ) 之间的一个数。超过这个比例范围会有明显的失真
    • interpolation: 选用如下插值方法将图像 resize 到输出尺寸
      • PIL.Image.NEAREST 最近邻差值
      • PIL.Image.BILINEAR 双线性差值(默认)
      • PIL.Image.BICUBIC 双三次差值

举例:

from PIL import Image
import torchvision.transforms as transforms
import matplotlib.pyplot as pltoriginal_img = Image.open('image.jpg')  # https://p.ipic.vip/7pvisy.jpg
print(original_img.size)   # (3280, 1818)img = transforms.RandomResizedCrop(1500)(original_img)plt.subplot(121)
plt.imshow(original_img)plt.subplot(122)
plt.imshow(img)plt.show()

在这里插入图片描述

版权声明:

本网仅为发布的内容提供存储空间,不对发表、转载的内容提供任何形式的保证。凡本网注明“来源:XXX网络”的作品,均转载自其它媒体,著作权归作者所有,商业转载请联系作者获得授权,非商业转载请注明出处。

我们尊重并感谢每一位作者,均已注明文章来源和作者。如因作品内容、版权或其它问题,请及时与我们联系,联系邮箱:809451989@qq.com,投稿邮箱:809451989@qq.com