You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
34 lines
1017 B
34 lines
1017 B
# Copyright (c) OpenMMLab. All rights reserved.
|
|
from collections import abc
|
|
|
|
import numpy as np
|
|
import torch
|
|
|
|
|
|
def cast_tensor_type(inputs, src_type, dst_type):
    """Recursively convert Tensor in inputs from src_type to dst_type.

    Only tensors whose dtype equals ``src_type`` are converted. Tensors of
    any other dtype are returned unchanged, e.g. integer label/index tensors
    when casting ``torch.float32`` -> ``torch.float16``. Mappings and other
    iterables are traversed recursively and rebuilt with the same container
    type; namedtuples keep their type. Strings, bytes, numpy arrays and
    ``torch.nn.Module`` instances are returned as-is.

    Args:
        inputs: Inputs to be cast.
        src_type (torch.dtype): Source type. Only tensors of this dtype are
            converted.
        dst_type (torch.dtype): Destination type.

    Returns:
        The same type with inputs, but all contained Tensors of ``src_type``
        have been cast to ``dst_type``.
    """
    if isinstance(inputs, torch.nn.Module):
        # Some modules are iterable (e.g. nn.Sequential) but cannot be rebuilt
        # from a generator; leave them untouched.
        return inputs
    elif isinstance(inputs, torch.Tensor):
        # Cast only tensors of the requested source dtype so that e.g. int64
        # indices are not silently turned into floating-point tensors.
        return inputs.to(dst_type) if inputs.dtype == src_type else inputs
    elif isinstance(inputs, (str, bytes, np.ndarray)):
        # str/bytes are iterable but must be treated as atomic values; numpy
        # arrays are deliberately left as-is.
        return inputs
    elif isinstance(inputs, abc.Mapping):
        return type(inputs)({
            k: cast_tensor_type(v, src_type, dst_type)
            for k, v in inputs.items()
        })
    elif isinstance(inputs, abc.Iterable):
        items = (cast_tensor_type(item, src_type, dst_type) for item in inputs)
        if isinstance(inputs, tuple) and hasattr(inputs, '_fields'):
            # namedtuple constructors take their fields positionally.
            return type(inputs)(*items)
        return type(inputs)(items)
    else:
        return inputs
|
|
|