当前位置: 首页> 游戏> 手游 > 利用 Python 的包管理和动态属性获取(`__init__.py` 文件和 `getattr` 函数)特性来实现工厂方法模式

利用 Python 的包管理和动态属性获取(`__init__.py` 文件和 `getattr` 函数)特性来实现工厂方法模式

时间:2025/7/12 3:18:12来源:https://blog.csdn.net/fydw_715/article/details/141217945 浏览次数:0次

Python 提供了许多灵活的特性,例如包的 __init__.py 文件和 getattr 函数,这些特性可以帮助我们实现工厂方法模式来动态地创建不同类型的数据集实例。

1. 背景介绍

在深度学习项目中,我们通常需要处理多种类型的数据集,例如 COCO、Pascal VOC 和自定义的交通数据集。为了简化和统一数据集的加载过程,我们可以利用 Python 的包管理和动态属性获取特性来实现工厂方法模式。

  • 包的 __init__.py 文件:通过在包的 __init__.py 文件中导入模块,我们可以在初始化包时自动加载所有必要的类和函数。
  • getattr 函数getattr 函数允许我们动态地获取对象的属性或方法,这对于实现工厂方法模式非常有用,因为我们可以根据配置或输入动态地创建对象,而无需在代码中硬编码每种数据集的构建逻辑。

接下来,我们将通过具体的代码示例来展示如何使用这些特性来实现数据集的动态加载。

2. 模块和类的定义

在我们的项目中,数据集类被定义在 datasets 模块中。我们将定义一个 COCODataset 类,并在 datasets 模块的 __init__.py 文件中导入它。需要注意的是,COCODataset 只是众多数据集类中的一种,其他数据集类如 PascalVOCDatasetTrafficDataset 等也可以通过类似的方式定义和使用。

定义 COCODataset

datasets 模块中创建一个名为 coco.py 的文件,并定义 COCODataset 类。这个类继承自 torchvision.datasets.coco.CocoDetection,并添加了一些自定义逻辑。

# datasets/coco.py
import torchvisionclass COCODataset(torchvision.datasets.coco.CocoDetection):def __init__(self, ann_file, root, remove_images_without_annotations, transforms=None):super(COCODataset, self).__init__(root, ann_file)# 自定义逻辑...
  • __init__ 方法COCODataset 类的构造函数接受 ann_file(注释文件路径)、root(图像根目录)、remove_images_without_annotations(是否移除没有注释的图像)和 transforms(图像变换)四个参数。这些参数与后面 DatasetCatalogget 方法返回的 args 对应。
  • 详细实现见附录
导入 COCODataset

datasets 模块的 __init__.py 文件中导入 COCODataset 类。这样可以确保在使用 datasets 模块时,所有数据集类都已加载。

# datasets/__init__.py
from .coco import COCODataset
from .voc import PascalVOCDataset
from .concat_dataset import ConcatDataset
from .traffic_dataset import TrafficDataset
from .carWinBiaoZhi_dataset import CarWinBiaoZhiDataset
from .carWinBiaoZhi_dataset_V2 import CarWinBiaoZhiDatasetV2
from .carWinBiaoZhi_dataset_V2_1 import CarWinBiaoZhiDatasetV2_1
from .GsData import CgTrafficData
from .GsData_xianQuan import CgTrafficDataWithXianQuan
from .GsData_1cls import CgTrafficData1Cls
from .GsData_ForSemi import CgTrafficDataSemi
from .GsRadarData import CgTrafficRadarData__all__ = ["COCODataset", "ConcatDataset", "PascalVOCDataset", "TrafficDataset","CarWinBiaoZhiDataset", "CarWinBiaoZhiDatasetV2", "CarWinBiaoZhiDatasetV2_1", "CgTrafficData", "CgTrafficDataWithXianQuan", "CgTrafficDataSemi", "CgTrafficRadarData", "CgTrafficData1Cls"
]

3. 使用 getattr 动态获取工厂方法

在构建数据集实例时,我们通过 getattr 函数动态获取工厂方法。以下是实现这一过程的核心代码:

# build_dataset.py
from . import datasets as Ddef build_dataset(dataset_list, transforms, dataset_catalog, is_train=True):if not isinstance(dataset_list, (list, tuple)):raise RuntimeError("dataset_list 应该是一个字符串列表,得到的是 {}".format(dataset_list))datasets = []  # 初始化数据集列表for dataset_name in dataset_list:# 从 dataset_catalog 中获取数据集信息data = dataset_catalog.get(dataset_name)# 获取数据集的工厂方法factory = getattr(D, data["factory"])# 获取数据集的参数args = data["args"]# 设置数据集的变换args["transforms"] = transforms# 使用工厂方法创建数据集实例dataset = factory(**args)# 将创建的数据集添加到列表中datasets.append(dataset)# 如果是测试模式,返回数据集列表if not is_train:return datasets# 如果是训练模式,将所有数据集合并为一个数据集dataset = datasets[0]if len(datasets) > 1:dataset = D.ConcatDataset(datasets)return [dataset]

4. 数据集目录管理 (DatasetCatalog)

为了集中管理数据集的路径和相关信息,我们定义了 DatasetCatalog 类。这个类包含了所有数据集的配置信息,并提供了一个静态方法 get 来获取特定数据集的配置信息。

# paths_catalog.py
import osclass DatasetCatalog(object):DATA_DIR = "/home/Public_DataSets"DATASETS = {"coco_2017_train": {"img_dir": "coco/train2017","ann_file": "coco/annotations/instances_train2017.json"},"voc_2007_train": {"data_dir": "voc/VOC2007","split": "train"},# ... 其他数据集配置 ...}@staticmethoddef get(name):if "coco" in name:data_dir = DatasetCatalog.DATA_DIRattrs = DatasetCatalog.DATASETS[name]args = dict(root=os.path.join(data_dir, attrs["img_dir"]),ann_file=os.path.join(data_dir, attrs["ann_file"]),)return dict(factory="COCODataset",args=args,)elif "voc" in name:data_dir = DatasetCatalog.DATA_DIRattrs = DatasetCatalog.DATASETS[name]args = dict(data_dir=os.path.join(data_dir, attrs["data_dir"]),split=attrs["split"],)return dict(factory="PascalVOCDataset",args=args,)# ... 其他数据集配置 ...raise RuntimeError("Dataset not available: {}".format(name))
说明

get 方法中,我们根据数据集名称动态生成配置字典。例如,对于 COCO 数据集:

return dict(factory="COCODataset",args=args,
)
  • factory:指定数据集类的名称,在后续步骤中用于动态获取工厂方法。
  • args:包含构建数据集实例所需的参数。

5. COCO 数据集的举例说明

假设我们有一个名为 "coco_2017_train" 的数据集,我们希望使用 DatasetCatalog 和工厂方法来加载这个数据集。以下是具体的步骤:

  1. 定义数据集配置

    # paths_catalog.py 中的 DATASETS 字典
    DATASETS = {"coco_2017_train": {"img_dir": "coco/train2017","ann_file": "coco/annotations/instances_train2017.json"},# ... 其他数据集配置 ...
    }
    
  2. 获取数据集配置

    data = DatasetCatalog.get("coco_2017_train")
    
  3. 动态获取工厂方法

    factory = getattr(D, data["factory"])
    
  4. 创建数据集实例

    args = data["args"]
    args["transforms"] = some_transform_function  # 假设我们有一个变换函数
    dataset = factory(**args)
    

通过这种方式,我们可以动态地加载 COCO 数据集,而无需硬编码每种数据集的构建逻辑。这种设计模式提高了代码的灵活性和可维护性,使得数据集的管理和加载更加方便。

附录: COCODataset 类完整实现
# datasets/coco.py
import torch
import torchvision
from maskrcnn_benchmark.structures.bounding_box import BoxList
from maskrcnn_benchmark.structures.segmentation_mask import SegmentationMask
from maskrcnn_benchmark.structures.keypoint import PersonKeypointsmin_keypoints_per_image = 10def has_valid_annotation(anno):if len(anno) == 0:return Falseif all(any(o <= 1 for o in obj["bbox"][2:]) for obj in anno):return Falseif "keypoints" not in anno[0]:return Trueif sum(sum(1 for v in ann["keypoints"][2::3] if v > 0) for ann in anno) >= min_keypoints_per_image:return Truereturn Falseclass COCODataset(torchvision.datasets.coco.CocoDetection):def __init__(self, ann_file, root, remove_images_without_annotations, transforms=None):super(COCODataset, self).__init__(root, ann_file)self.ids = sorted(self.ids)if remove_images_without_annotations:self.ids = [img_id for img_id in self.ids if has_valid_annotation(self.coco.loadAnns(self.coco.getAnnIds(imgIds=img_id)))]self.categories = {cat['id']: cat['name'] for cat in self.coco.cats.values()}self.json_category_id_to_contiguous_id = {v: i + 1 for i, v in enumerate(self.coco.getCatIds())}self.contiguous_category_id_to_json_id = {v: k for k, v in self.json_category_id_to_contiguous_id.items()}self.id_to_img_map = {k: v for k, v in enumerate(self.ids)}self._transforms = transformsdef __getitem__(self, idx):img, anno = super(COCODataset, self).__getitem__(idx)anno = [obj for obj in anno if obj["iscrowd"] == 0]boxes = [obj["bbox"] for obj in anno]boxes = torch.as_tensor(boxes).reshape(-1, 4)target = BoxList(boxes, img.size, mode="xywh").convert("xyxy")classes = torch.tensor([self.json_category_id_to_contiguous_id[obj["category_id"]] for obj in anno])target.add_field("labels", classes)if anno and "segmentation" in anno[0]:masks = SegmentationMask([obj["segmentation"] for obj in anno], img.size, mode='poly')target.add_field("masks", masks)if anno and "keypoints" in anno[0]:keypoints = PersonKeypoints([obj["keypoints"] for obj in anno], img.size)target.add_field("keypoints", keypoints)target = target.clip_to_image(remove_empty=True)if self._transforms is not None:img, target = self._transforms(img, target)return img, target, idxdef get_img_info(self, index):return self.coco.imgs[self.id_to_img_map[index]]
  • __init__ 方法:初始化数据集,加载注释,过滤无效注释,并设置类别和图像映射。
  • __getitem__ 方法:获取指定索引的图像和注释,应用可选的变换,并返回图像、目标和索引。
关键字:利用 Python 的包管理和动态属性获取(`__init__.py` 文件和 `getattr` 函数)特性来实现工厂方法模式

版权声明:

本网仅为发布的内容提供存储空间,不对发表、转载的内容提供任何形式的保证。凡本网注明“来源:XXX网络”的作品,均转载自其它媒体,著作权归作者所有,商业转载请联系作者获得授权,非商业转载请注明出处。

我们尊重并感谢每一位作者,均已注明文章来源和作者。如因作品内容、版权或其它问题,请及时与我们联系,联系邮箱:809451989@qq.com,投稿邮箱:809451989@qq.com

责任编辑: