# Copyright (c) OpenMMLab. All rights reserved.
import warnings
from collections.abc import Sequence

import mmcv
import numpy as np
from mmcv.parallel import DataContainer as DC
from mmcv.utils import build_from_cfg
from numpy import random
from torchvision.transforms import functional as F

from ..builder import PIPELINES

try:
    import albumentations
except ImportError:
    albumentations = None


def _map_images(images, fn):
    """Apply ``fn`` to a single image or to every image of a list/tuple.

    Args:
        images: A single image, or a list/tuple of images.
        fn (callable): Per-image transform.

    Returns:
        ``fn(images)`` for a single image; otherwise a new list holding
        ``fn(img)`` for each element of ``images``.
    """
    if isinstance(images, (list, tuple)):
        return [fn(img) for img in images]
    return fn(images)


@PIPELINES.register_module()
class ToTensorSam:
    """Convert the images under 'img' and 'sam_img' to tensors.

    Each key may hold either a single image or a list/tuple of images. The
    container check is done per key, so 'img' and 'sam_img' do not need to
    share a container type.

    Required keys: 'img', 'sam_img'.
    Modified keys: 'img', 'sam_img'.
    """

    def __call__(self, results):
        """Convert 'img' and 'sam_img' in ``results`` with ``F.to_tensor``.

        Args:
            results (dict): Pipeline results dict containing 'img' and
                'sam_img'.

        Returns:
            dict: The same ``results`` dict, with both entries replaced by
            their converted value (a list of converted values when the
            entry was a list/tuple).

        Raises:
            KeyError: If 'img' or 'sam_img' is missing from ``results``.
        """
        # Handle each key independently so a list-valued 'img' does not
        # force 'sam_img' to be iterated (and vice versa).
        for key in ('img', 'sam_img'):
            results[key] = _map_images(results[key], F.to_tensor)
        return results


@PIPELINES.register_module()
class NormalizeTensorSam:
    """Normalize the tensor images (CxHxW) under 'img' and 'sam_img'.

    Both keys are normalized with the same ``mean`` and ``std``. Each key may
    hold either a single tensor or a list/tuple of tensors; the container
    check is done per key.

    Required keys: 'img', 'sam_img'.
    Modified keys: 'img', 'sam_img'.

    Args:
        mean (list[float]): Mean values of 3 channels.
        std (list[float]): Std values of 3 channels.
    """

    def __init__(self, mean, std):
        self.mean = mean
        self.std = std

    def __call__(self, results):
        """Normalize 'img' and 'sam_img' in ``results`` with ``F.normalize``.

        Args:
            results (dict): Pipeline results dict containing 'img' and
                'sam_img'.

        Returns:
            dict: The same ``results`` dict, with both entries replaced by
            their normalized value (a list of normalized values when the
            entry was a list/tuple).

        Raises:
            KeyError: If 'img' or 'sam_img' is missing from ``results``.
        """

        def normalize(img):
            return F.normalize(img, mean=self.mean, std=self.std)

        # Handle each key independently so a list-valued 'img' does not
        # force 'sam_img' to be iterated (and vice versa).
        for key in ('img', 'sam_img'):
            results[key] = _map_images(results[key], normalize)
        return results