You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
86 lines
2.6 KiB
86 lines
2.6 KiB
# Copyright (c) OpenMMLab. All rights reserved.
|
|
from abc import ABCMeta, abstractmethod, abstractproperty
|
|
|
|
import torch
|
|
|
|
|
|
class PytorchModuleHook(metaclass=ABCMeta):
|
|
"""Base class for PyTorch module hook registers.
|
|
|
|
An instance of a subclass of PytorchModuleHook can be used to
|
|
register hook to a pytorch module using the `register` method like:
|
|
hook_register.register(module)
|
|
|
|
Subclasses should add/overwrite the following methods:
|
|
- __init__
|
|
- hook
|
|
- hook_type
|
|
"""
|
|
|
|
@abstractmethod
|
|
def hook(self, *args, **kwargs):
|
|
"""Hook function."""
|
|
|
|
@abstractproperty
|
|
def hook_type(self) -> str:
|
|
"""Hook type Subclasses should overwrite this function to return a
|
|
string value in.
|
|
|
|
{`forward`, `forward_pre`, `backward`}
|
|
"""
|
|
|
|
def register(self, module):
|
|
"""Register the hook function to the module.
|
|
|
|
Args:
|
|
module (pytorch module): the module to register the hook.
|
|
|
|
Returns:
|
|
handle (torch.utils.hooks.RemovableHandle): a handle to remove
|
|
the hook by calling handle.remove()
|
|
"""
|
|
assert isinstance(module, torch.nn.Module)
|
|
|
|
if self.hook_type == 'forward':
|
|
h = module.register_forward_hook(self.hook)
|
|
elif self.hook_type == 'forward_pre':
|
|
h = module.register_forward_pre_hook(self.hook)
|
|
elif self.hook_type == 'backward':
|
|
h = module.register_backward_hook(self.hook)
|
|
else:
|
|
raise ValueError(f'Invalid hook type {self.hook}')
|
|
|
|
return h
|
|
|
|
|
|
class WeightNormClipHook(PytorchModuleHook):
|
|
"""Apply weight norm clip regularization.
|
|
|
|
The module's parameter will be clip to a given maximum norm before each
|
|
forward pass.
|
|
|
|
Args:
|
|
max_norm (float): The maximum norm of the parameter.
|
|
module_param_names (str|list): The parameter name (or name list) to
|
|
apply weight norm clip.
|
|
"""
|
|
|
|
def __init__(self, max_norm=1.0, module_param_names='weight'):
|
|
self.module_param_names = module_param_names if isinstance(
|
|
module_param_names, list) else [module_param_names]
|
|
self.max_norm = max_norm
|
|
|
|
@property
|
|
def hook_type(self):
|
|
return 'forward_pre'
|
|
|
|
def hook(self, module, _input):
|
|
for name in self.module_param_names:
|
|
assert name in module._parameters, f'{name} is not a parameter' \
|
|
f' of the module {type(module)}'
|
|
param = module._parameters[name]
|
|
|
|
with torch.no_grad():
|
|
m = param.norm().item()
|
|
if m > self.max_norm:
|
|
param.mul_(self.max_norm / (m + 1e-6))
|
|
|