You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
175 lines
6.7 KiB
175 lines
6.7 KiB
# Copyright (c) OpenMMLab. All rights reserved.
|
|
import functools
|
|
import warnings
|
|
from inspect import getfullargspec
|
|
|
|
import torch
|
|
|
|
from .utils import cast_tensor_type
|
|
|
|
|
|
def auto_fp16(apply_to=None, out_fp32=False):
    """Decorator to enable fp16 training automatically.

    This decorator is useful when you write custom modules and want to support
    mixed precision training. If inputs arguments are fp32 tensors, they will
    be converted to fp16 automatically. Arguments other than fp32 tensors are
    ignored.

    The conversion is only performed when the module instance (the first
    positional argument of the decorated method) has a truthy
    ``fp16_enabled`` attribute; otherwise the method is called unchanged.

    Args:
        apply_to (Iterable, optional): The argument names to be converted.
            `None` indicates all named positional arguments of the method
            (variadic ``*args`` and keyword-only arguments are not included).
        out_fp32 (bool): Whether to convert the output back to fp32.

    Returns:
        callable: A decorator that wraps a method of ``nn.Module``.

    Raises:
        TypeError: Raised at call time if the first positional argument of
            the decorated callable is missing or is not an ``nn.Module``.

    Example:

        >>> import torch.nn as nn
        >>> class MyModule1(nn.Module):
        >>>
        >>>     # Convert x and y to fp16
        >>>     @auto_fp16()
        >>>     def forward(self, x, y):
        >>>         pass

        >>> import torch.nn as nn
        >>> class MyModule2(nn.Module):
        >>>
        >>>     # convert pred to fp16
        >>>     @auto_fp16(apply_to=('pred', ))
        >>>     def do_something(self, pred, others):
        >>>         pass
    """

    # Emitted once per decoration (i.e. when ``auto_fp16(...)`` is evaluated),
    # not on every call of the decorated method.
    warnings.warn(
        'auto_fp16 in mmpose will be deprecated in the next release. '
        'Please use mmcv.runner.auto_fp16 instead (mmcv>=1.3.1).',
        DeprecationWarning)

    def auto_fp16_wrapper(old_func):

        @functools.wraps(old_func)
        def new_func(*args, **kwargs):
            # The wrapper relies on ``args[0]`` being the module instance;
            # guard against an empty ``args`` so the caller gets the intended
            # TypeError rather than an IndexError.
            if not args or not isinstance(args[0], torch.nn.Module):
                raise TypeError('@auto_fp16 can only be used to decorate the '
                                'method of nn.Module')
            # If the module has not set a truthy `fp16_enabled` attribute,
            # just fall back to the original method.
            if not getattr(args[0], 'fp16_enabled', False):
                return old_func(*args, **kwargs)

            # get the arg spec of the decorated method
            args_info = getfullargspec(old_func)
            # get the argument names to be casted
            args_to_cast = args_info.args if apply_to is None else apply_to

            # convert the positional args that need to be processed
            # NOTE: default args are not taken into consideration
            arg_names = args_info.args[:len(args)]
            new_args = [
                cast_tensor_type(arg_value, torch.float, torch.half)
                if arg_name in args_to_cast else arg_value
                for arg_name, arg_value in zip(arg_names, args)
            ]
            # Positional arguments beyond the named parameters (collected by
            # ``*varargs`` in the decorated method) have no name to match
            # against ``args_to_cast``; forward them unchanged instead of
            # dropping them.
            new_args.extend(args[len(arg_names):])

            # convert the kwargs that need to be processed
            new_kwargs = {
                arg_name:
                cast_tensor_type(arg_value, torch.float, torch.half)
                if arg_name in args_to_cast else arg_value
                for arg_name, arg_value in kwargs.items()
            }

            # apply converted arguments to the decorated method
            output = old_func(*new_args, **new_kwargs)
            # cast the results back to fp32 if necessary
            if out_fp32:
                output = cast_tensor_type(output, torch.half, torch.float)
            return output

        return new_func

    return auto_fp16_wrapper
|
|
|
|
|
|
def force_fp32(apply_to=None, out_fp16=False):
    """Decorator to convert input arguments to fp32 in force.

    This decorator is useful when you write custom modules and want to support
    mixed precision training. If there are some inputs that must be processed
    in fp32 mode, then this decorator can handle it. If inputs arguments are
    fp16 tensors, they will be converted to fp32 automatically. Arguments other
    than fp16 tensors are ignored.

    The conversion is only performed when the module instance (the first
    positional argument of the decorated method) has a truthy
    ``fp16_enabled`` attribute; otherwise the method is called unchanged.

    Args:
        apply_to (Iterable, optional): The argument names to be converted.
            `None` indicates all named positional arguments of the method
            (variadic ``*args`` and keyword-only arguments are not included).
        out_fp16 (bool): Whether to convert the output back to fp16.

    Returns:
        callable: A decorator that wraps a method of ``nn.Module``.

    Raises:
        TypeError: Raised at call time if the first positional argument of
            the decorated callable is missing or is not an ``nn.Module``.

    Example:

        >>> import torch.nn as nn
        >>> class MyModule1(nn.Module):
        >>>
        >>>     # Convert x and y to fp32
        >>>     @force_fp32()
        >>>     def loss(self, x, y):
        >>>         pass

        >>> import torch.nn as nn
        >>> class MyModule2(nn.Module):
        >>>
        >>>     # convert pred to fp32
        >>>     @force_fp32(apply_to=('pred', ))
        >>>     def post_process(self, pred, others):
        >>>         pass
    """
    # Emitted once per decoration (i.e. when ``force_fp32(...)`` is
    # evaluated), not on every call of the decorated method.
    warnings.warn(
        'force_fp32 in mmpose will be deprecated in the next release. '
        'Please use mmcv.runner.force_fp32 instead (mmcv>=1.3.1).',
        DeprecationWarning)

    def force_fp32_wrapper(old_func):

        @functools.wraps(old_func)
        def new_func(*args, **kwargs):
            # The wrapper relies on ``args[0]`` being the module instance;
            # guard against an empty ``args`` so the caller gets the intended
            # TypeError rather than an IndexError.
            if not args or not isinstance(args[0], torch.nn.Module):
                raise TypeError('@force_fp32 can only be used to decorate the '
                                'method of nn.Module')
            # If the module has not set a truthy `fp16_enabled` attribute,
            # just fall back to the original method.
            if not getattr(args[0], 'fp16_enabled', False):
                return old_func(*args, **kwargs)

            # get the arg spec of the decorated method
            args_info = getfullargspec(old_func)
            # get the argument names to be casted
            args_to_cast = args_info.args if apply_to is None else apply_to

            # convert the positional args that need to be processed
            # NOTE: default args are not taken into consideration
            arg_names = args_info.args[:len(args)]
            new_args = [
                cast_tensor_type(arg_value, torch.half, torch.float)
                if arg_name in args_to_cast else arg_value
                for arg_name, arg_value in zip(arg_names, args)
            ]
            # Positional arguments beyond the named parameters (collected by
            # ``*varargs`` in the decorated method) have no name to match
            # against ``args_to_cast``; forward them unchanged instead of
            # dropping them.
            new_args.extend(args[len(arg_names):])

            # convert the kwargs that need to be processed
            new_kwargs = {
                arg_name:
                cast_tensor_type(arg_value, torch.half, torch.float)
                if arg_name in args_to_cast else arg_value
                for arg_name, arg_value in kwargs.items()
            }

            # apply converted arguments to the decorated method
            output = old_func(*new_args, **new_kwargs)
            # cast the results back to fp16 if requested
            if out_fp16:
                output = cast_tensor_type(output, torch.float, torch.half)
            return output

        return new_func

    return force_fp32_wrapper
|
|
|