You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
76 lines
2.1 KiB
76 lines
2.1 KiB
# Copyright (c) OpenMMLab. All rights reserved.
|
|
import warnings
|
|
from collections.abc import Sequence
|
|
|
|
import mmcv
|
|
import numpy as np
|
|
from mmcv.parallel import DataContainer as DC
|
|
from mmcv.utils import build_from_cfg
|
|
from numpy import random
|
|
from torchvision.transforms import functional as F
|
|
|
|
from ..builder import PIPELINES
|
|
|
|
try:
|
|
import albumentations
|
|
except ImportError:
|
|
albumentations = None
|
|
|
|
|
|
@PIPELINES.register_module()
class ToTensorSam:
    """Convert the 'img' and 'sam_img' entries of ``results`` to tensors.

    Required keys: 'img', 'sam_img'. Modified keys: 'img', 'sam_img'.

    Each entry may be a single image or a list/tuple of images. A list/tuple
    is converted element-wise with ``F.to_tensor`` and stored back as a list;
    a single image is converted directly.
    """

    @staticmethod
    def _to_tensor(value):
        """Apply ``F.to_tensor`` to one image, or to each image of a list/tuple."""
        if isinstance(value, (list, tuple)):
            return [F.to_tensor(img) for img in value]
        return F.to_tensor(value)

    def __call__(self, results):
        """Convert 'img' and 'sam_img' in place.

        Args:
            results (dict): Pipeline results; must contain 'img' and 'sam_img'.

        Returns:
            dict: The same ``results`` dict with 'img' and 'sam_img' replaced
            by their tensor (or list-of-tensor) versions.

        Raises:
            KeyError: If 'img' or 'sam_img' is missing from ``results``.
        """
        # Decide list-vs-single for each key on its own. Keying 'sam_img' off
        # the type of 'img' would iterate a single array row-by-row (or pass
        # a list to F.to_tensor) whenever the two containers differ.
        for key in ('img', 'sam_img'):
            results[key] = self._to_tensor(results[key])
        return results

    def __repr__(self):
        return self.__class__.__name__
|
|
|
|
|
|
@PIPELINES.register_module()
class NormalizeTensorSam:
    """Normalize the tensor images in 'img' and 'sam_img' with mean and std.

    Required keys: 'img', 'sam_img'. Modified keys: 'img', 'sam_img'.

    Each entry may be a single tensor image (CxHxW) or a list/tuple of them.
    A list/tuple is normalized element-wise with ``F.normalize`` and stored
    back as a list; a single tensor is normalized directly.

    Args:
        mean (list[float]): Mean values of 3 channels.
        std (list[float]): Std values of 3 channels.
    """

    def __init__(self, mean, std):
        self.mean = mean
        self.std = std

    def _normalize(self, value):
        """Apply ``F.normalize`` to one tensor, or to each tensor of a list/tuple."""
        if isinstance(value, (list, tuple)):
            return [
                F.normalize(tensor, mean=self.mean, std=self.std)
                for tensor in value
            ]
        return F.normalize(value, mean=self.mean, std=self.std)

    def __call__(self, results):
        """Normalize 'img' and 'sam_img' in place.

        Args:
            results (dict): Pipeline results; must contain 'img' and 'sam_img'.

        Returns:
            dict: The same ``results`` dict with 'img' and 'sam_img' replaced
            by their normalized versions.

        Raises:
            KeyError: If 'img' or 'sam_img' is missing from ``results``.
        """
        # Handle each key's container type independently. Keying 'sam_img'
        # off the type of 'img' breaks whenever the two containers differ.
        for key in ('img', 'sam_img'):
            results[key] = self._normalize(results[key])
        return results

    def __repr__(self):
        return '{}(mean={}, std={})'.format(
            self.__class__.__name__, self.mean, self.std)
|
|
|