当前位置: 首页> 健康> 母婴 > 深度学习--数据处理dataloader介绍及代码分析

深度学习--数据处理dataloader介绍及代码分析

时间:2025/7/13 4:51:20来源:https://blog.csdn.net/weixin_42750325/article/details/141001171 浏览次数:0次

dataloader介绍

  • dataloader
    • 概述
    • collate_fn
      • 主要作用
      • 在代码中的使用
  • 代码详解
    • 代码解释
      • __init__函数
      • collate_fn
      • 详细说明
  • 完整代码

dataloader

概述

参考博客
DataLoader是深度学习中重要的数据处理工具之一,旨在有效加载、处理和管理大规模数据集,用于训练和测试机器学习和深度学习模型。
DataLoader是一个用于批量加载数据的工具,它可以将数据集分成多个小批量(mini-batch),并逐个加载,以适应模型训练的需要。
DataLoader主要用于两个关键任务:数据加载和批次处理

  • 数据加载:DataLoader可以从不同来源加载数据,如硬盘上的文件、数据库、网络等。它能够自动将数据集划分为小批次,从而减小内存需求,确保数据的高效加载。
  • 数据批次处理:每个批次由多个样本组成,可以并行地进行数据预处理和数据增强。这有助于提高模型训练的效率,同时确保每个批次的数据都经过适当的处理。

collate_fn

collate_fn 是一个自定义函数,用于在 PyTorch 的 DataLoader 中定义如何将单个样本组合成一个批次(batch)。具体来说,collate_fn 函数会在每次从 DataLoader 中取出一个批次的数据时被调用,用于对数据进行整理和转换。

主要作用

collate_fn:返回值为最终构建的batch数据;在这一步中处理dataset的数据,将其调整成期望的数据格式。
将一个批次的数据样本整理成适合模型输入的格式,特别是将数据转换为 PyTorch 张量(Tensor),以便于后续的模型训练和推理。

  • 自定义数据堆叠:将单个样本组合成一个批次,处理数据的不同形状或类型。
  • 数据转换:在批次数据组成之前进行必要的转换操作,例如数据类型转换、数据增强等。

在代码中的使用

在本代码中,unet_dataset_collate 函数就是一个 collate_fn 函数。它的作用是将一个批次的数据样本(图像、PNG 数据和分割标签)整理成适合模型输入的格式。具体步骤包括将数据从列表转换为 NumPy 数组,再转换为 PyTorch 张量。

代码详解

# DataLoader中collate_fn使用
def unet_dataset_collate(batch):images      = []pngs        = []seg_labels  = []for img, png, labels in batch:images.append(img)pngs.append(png)seg_labels.append(labels)images      = torch.from_numpy(np.array(images)).type(torch.FloatTensor)pngs        = torch.from_numpy(np.array(pngs)).long()seg_labels  = torch.from_numpy(np.array(seg_labels)).type(torch.FloatTensor)return images, pngs, seg_labels

这段代码定义了一个名为 unet_dataset_collate 的函数,用于在 PyTorch 的 DataLoader 中自定义批处理方式。函数将一个批次的数据样本(batch)转换为适合模型输入的格式。

代码解释

__init__函数

在 DataLoader 中,init 函数的主要作用是初始化数据集对象,并为后续的数据加载和处理做好准备。
UnetDataset 类的 init 函数在 DataLoader 中的作用包括:

  • 数据集初始化:通过传入的参数(如 annotation_lines、input_shape 等)初始化数据集对象,使其包含所有必要的信息。
  • 数据预处理:在初始化过程中,可以对数据进行预处理,如归一化、裁剪等,以便后续的模型训练。
  • 数据分割:将数据集分割成训练集和验证集(通过 train 参数),以便在训练过程中进行模型评估。
  • 路径管理:通过 dataset_path 参数指定数据集的存储路径,方便数据的加载和管理。
# UnetDataset 类的初始化方法,接受五个参数:annotation_lines、input_shape、num_classes、train 和 dataset_path。def __init__(self, annotation_lines, input_shape, num_classes, train, dataset_path):
# super() 函数用于调用父类的初始化方法。在这里,它调用了 UnetDataset 类的父类的 __init__ 方法,确保父类的初始化逻辑也被执行。这对于继承自其他类的类非常重要。super(UnetDataset, self).__init__()
# self 代表类的实例。self.annotation_lines 将传入的 annotation_lines 参数赋值给实例属性 annotation_linesself.annotation_lines   = annotation_linesself.length             = len(annotation_lines)self.input_shape        = input_shapeself.num_classes        = num_classesself.train              = trainself.dataset_path       = dataset_path

解释 super 和 self

  • super
    super() 函数用于调用父类的方法。在多重继承的情况下,它确保正确调用父类的方法,避免重复调用。这里,它调用了 UnetDataset 类的父类的 init 方法。
  • self
    self 是类的实例的引用。它用于访问类的属性和方法。在类的方法中,self 必须作为第一个参数传递,以便方法能够访问实例的属性和其他方法。

collate_fn

# DataLoader中collate_fn使用
# 函数定义:net_dataset_collate(batch):定义了一个函数,接收一个批次的数据样本batch。
def unet_dataset_collate(batch):
# 初始化列表:
# images = []:用于存储所有图像数据。
# pngs = []:用于存储所有 PNG 格式的数据。
# seg_labels = []:用于存储所有分割标签数据images      = []pngs        = []seg_labels  = []
# 遍历批次数据:
# 遍历批次中的每个样本,假设每个样本包含图像、PNG 数据和分割标签。
# images.append(img):将图像数据添加到 images 列表中。
# pngs.append(png):将 PNG 数据添加到 pngs 列表中。
# seg_labels.append(labels):将分割标签数据添加到 seg_labels 列表中。for img, png, labels in batch:images.append(img)pngs.append(png)seg_labels.append(labels)
#转换数据类型:
# 将 images 列表转换为 NumPy 数组,再转换为 PyTorch 的 FloatTensor 类型。
# 将 pngs 列表转换为 NumPy 数组,再转换为 PyTorch 的 LongTensor 类型。
# 将 seg_labels 列表转换为 NumPy 数组,再转换为 PyTorch 的 FloatTensor 类型。images      = torch.from_numpy(np.array(images)).type(torch.FloatTensor)pngs        = torch.from_numpy(np.array(pngs)).long()seg_labels  = torch.from_numpy(np.array(seg_labels)).type(torch.FloatTensor)
# 返回结果:
# 返回处理后的图像数据、PNG 数据和分割标签数据。return images, pngs, seg_labels

详细说明

  1. 函数定义

    • unet_dataset_collate(batch):定义了一个函数,接收一个批次的数据样本 batch
  2. 初始化列表

    • images = []:用于存储所有图像数据。
    • pngs = []:用于存储所有 PNG 格式的数据。
    • seg_labels = []:用于存储所有分割标签数据。
  3. 遍历批次数据

    • for img, png, labels in batch::遍历批次中的每个样本,假设每个样本包含图像、PNG 数据和分割标签。
    • images.append(img):将图像数据添加到 images 列表中。
    • pngs.append(png):将 PNG 数据添加到 pngs 列表中。
    • seg_labels.append(labels):将分割标签数据添加到 seg_labels 列表中。
  4. 转换数据类型

    • images = torch.from_numpy(np.array(images)).type(torch.FloatTensor):将 images 列表转换为 NumPy 数组,再转换为 PyTorch 的 FloatTensor 类型。
    • pngs = torch.from_numpy(np.array(pngs)).long():将 pngs 列表转换为 NumPy 数组,再转换为 PyTorch 的 LongTensor 类型。
    • seg_labels = torch.from_numpy(np.array(seg_labels)).type(torch.FloatTensor):将 seg_labels 列表转换为 NumPy 数组,再转换为 PyTorch 的 FloatTensor 类型。
  5. 返回结果

    • return images, pngs, seg_labels:返回处理后的图像数据、PNG 数据和分割标签数据。

完整代码

import osimport cv2
import numpy as np
import torch
from PIL import Image
from torch.utils.data.dataset import Datasetfrom utils.utils import cvtColor, preprocess_inputclass UnetDataset(Dataset):def __init__(self, annotation_lines, input_shape, num_classes, train, dataset_path):super(UnetDataset, self).__init__()self.annotation_lines   = annotation_linesself.length             = len(annotation_lines)self.input_shape        = input_shapeself.num_classes        = num_classesself.train              = trainself.dataset_path       = dataset_pathdef __len__(self):return self.lengthdef __getitem__(self, index):annotation_line = self.annotation_lines[index]name            = annotation_line.split()[0]#-------------------------------##   从文件中读取图像#-------------------------------#jpg         = Image.open(os.path.join(os.path.join(self.dataset_path, "JPEGImages"), name + ".jpg"))png         = Image.open(os.path.join(os.path.join(self.dataset_path, "SegmentationClass"), name + ".png"))#-------------------------------##   数据增强#-------------------------------#jpg, png    = self.get_random_data(jpg, png, self.input_shape, random = self.train)jpg         = np.transpose(preprocess_input(np.array(jpg, np.float64)), [2,0,1])png         = np.array(png)png[png >= self.num_classes] = self.num_classes#-------------------------------------------------------##   转化成one_hot的形式#   在这里需要+1是因为voc数据集有些标签具有白边部分#   我们需要将白边部分进行忽略,+1的目的是方便忽略。#-------------------------------------------------------#seg_labels  = np.eye(self.num_classes + 1)[png.reshape([-1])]seg_labels  = seg_labels.reshape((int(self.input_shape[0]), int(self.input_shape[1]), self.num_classes + 1))return jpg, png, seg_labelsdef rand(self, a=0, b=1):return np.random.rand() * (b - a) + adef get_random_data(self, image, label, input_shape, jitter=.3, hue=.1, sat=0.7, val=0.3, random=True):image   = cvtColor(image)label   = Image.fromarray(np.array(label))#------------------------------##   获得图像的高宽与目标高宽#------------------------------#iw, ih  = image.sizeh, w    = input_shapeif not random:iw, ih  = image.sizescale   = min(w/iw, h/ih)nw      = int(iw*scale)nh      = int(ih*scale)image       = image.resize((nw,nh), Image.BICUBIC)new_image   = Image.new('RGB', [w, h], (128,128,128))new_image.paste(image, ((w-nw)//2, (h-nh)//2))label       = label.resize((nw,nh), Image.NEAREST)new_label   = Image.new('L', [w, h], (0))new_label.paste(label, ((w-nw)//2, (h-nh)//2))return new_image, new_label#------------------------------------------##   对图像进行缩放并且进行长和宽的扭曲#------------------------------------------#new_ar = iw/ih * self.rand(1-jitter,1+jitter) / self.rand(1-jitter,1+jitter)scale = self.rand(0.25, 2)if new_ar < 1:nh = int(scale*h)nw = int(nh*new_ar)else:nw = int(scale*w)nh = int(nw/new_ar)image = image.resize((nw,nh), Image.BICUBIC)label = label.resize((nw,nh), Image.NEAREST)#------------------------------------------##   翻转图像#------------------------------------------#flip = self.rand()<.5if flip: image = image.transpose(Image.FLIP_LEFT_RIGHT)label = label.transpose(Image.FLIP_LEFT_RIGHT)#------------------------------------------##   将图像多余的部分加上灰条#------------------------------------------#dx = int(self.rand(0, w-nw))dy = int(self.rand(0, h-nh))new_image = Image.new('RGB', (w,h), (128,128,128))new_label = Image.new('L', (w,h), (0))new_image.paste(image, (dx, dy))new_label.paste(label, (dx, dy))image = new_imagelabel = new_labelimage_data      = np.array(image, np.uint8)#---------------------------------##   对图像进行色域变换#   计算色域变换的参数#---------------------------------#r               = np.random.uniform(-1, 1, 3) * [hue, sat, val] + 1#---------------------------------##   将图像转到HSV上#---------------------------------#hue, sat, val   = cv2.split(cv2.cvtColor(image_data, cv2.COLOR_RGB2HSV))dtype           = image_data.dtype#---------------------------------##   应用变换#---------------------------------#x       = np.arange(0, 256, dtype=r.dtype)lut_hue = ((x * r[0]) % 180).astype(dtype)lut_sat = np.clip(x * r[1], 0, 255).astype(dtype)lut_val = np.clip(x * r[2], 0, 255).astype(dtype)image_data = cv2.merge((cv2.LUT(hue, lut_hue), cv2.LUT(sat, lut_sat), cv2.LUT(val, lut_val)))image_data = cv2.cvtColor(image_data, cv2.COLOR_HSV2RGB)return image_data, label# DataLoader中collate_fn使用
def unet_dataset_collate(batch):images      = []pngs        = []seg_labels  = []for img, png, labels in batch:images.append(img)pngs.append(png)seg_labels.append(labels)images      = torch.from_numpy(np.array(images)).type(torch.FloatTensor)pngs        = torch.from_numpy(np.array(pngs)).long()seg_labels  = torch.from_numpy(np.array(seg_labels)).type(torch.FloatTensor)return images, pngs, seg_labels
关键字:深度学习--数据处理dataloader介绍及代码分析

版权声明:

本网仅为发布的内容提供存储空间,不对发表、转载的内容提供任何形式的保证。凡本网注明“来源:XXX网络”的作品,均转载自其它媒体,著作权归作者所有,商业转载请联系作者获得授权,非商业转载请注明出处。

我们尊重并感谢每一位作者,均已注明文章来源和作者。如因作品内容、版权或其它问题,请及时与我们联系,联系邮箱:809451989@qq.com,投稿邮箱:809451989@qq.com

责任编辑: