You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

76 lines
2.1 KiB

# Copyright (c) OpenMMLab. All rights reserved.
import warnings
from collections.abc import Sequence
import mmcv
import numpy as np
from mmcv.parallel import DataContainer as DC
from mmcv.utils import build_from_cfg
from numpy import random
from torchvision.transforms import functional as F
from ..builder import PIPELINES
try:
import albumentations
except ImportError:
albumentations = None
@PIPELINES.register_module()
class ToTensorSam:
    """Convert the image and its SAM counterpart to tensors.

    Required keys: 'img', 'sam_img'. Modified keys: 'img', 'sam_img'.

    Each key may hold either a single image or a list/tuple of images. The
    two keys are inspected independently, so a list under 'img' does not
    force list handling of 'sam_img' (and vice versa).
    """

    @staticmethod
    def _to_tensor(value):
        """Apply ``F.to_tensor`` to one image, or to each image in a list/tuple."""
        if isinstance(value, (list, tuple)):
            return [F.to_tensor(img) for img in value]
        return F.to_tensor(value)

    def __call__(self, results):
        """Convert ``results['img']`` and ``results['sam_img']`` to tensors.

        Args:
            results (dict): Pipeline results dict containing 'img' and
                'sam_img'.

        Returns:
            dict: The same ``results`` dict, with both entries replaced by
            tensors (a list/tuple input becomes a list of tensors).

        Raises:
            KeyError: If 'img' or 'sam_img' is missing from ``results``.
        """
        # 'img' is converted first, then 'sam_img', matching the original order.
        for key in ('img', 'sam_img'):
            results[key] = self._to_tensor(results[key])
        return results
@PIPELINES.register_module()
class NormalizeTensorSam:
    """Normalize tensor images (CxHxW) with a per-channel mean and std.

    Required keys: 'img', 'sam_img'. Modified keys: 'img', 'sam_img'.
    Both entries are normalized with the same ``mean`` and ``std``.

    Each key may hold either a single tensor or a list/tuple of tensors. The
    two keys are inspected independently, so a list under 'img' does not
    force list handling of 'sam_img' (and vice versa).

    Args:
        mean (list[float]): Per-channel mean values (e.g. 3 for RGB).
        std (list[float]): Per-channel std values (e.g. 3 for RGB).
    """

    def __init__(self, mean, std):
        self.mean = mean
        self.std = std

    def _normalize(self, value):
        """Normalize one tensor, or each tensor in a list/tuple."""
        if isinstance(value, (list, tuple)):
            return [
                F.normalize(img, mean=self.mean, std=self.std)
                for img in value
            ]
        return F.normalize(value, mean=self.mean, std=self.std)

    def __call__(self, results):
        """Normalize ``results['img']`` and ``results['sam_img']``.

        Args:
            results (dict): Pipeline results dict containing 'img' and
                'sam_img' as tensors (or lists/tuples of tensors).

        Returns:
            dict: The same ``results`` dict, with both entries normalized
            (a list/tuple input becomes a list of tensors).

        Raises:
            KeyError: If 'img' or 'sam_img' is missing from ``results``.
        """
        # 'img' is normalized first, then 'sam_img', matching the original order.
        for key in ('img', 'sam_img'):
            results[key] = self._normalize(results[key])
        return results