# Copyright (c) Open-MMLab. All rights reserved.
import os.path as osp
import platform
import shutil

import torch
from torch.optim import Optimizer

import mmcv
from mmcv.runner import RUNNERS, IterBasedRunner

from .checkpoint import save_checkpoint

# apex (NVIDIA's mixed-precision library) is optional at import time, but
# resuming a checkpoint that carries an 'amp' entry requires it.
try:
    import apex
except ImportError:
    print('apex is not installed')

@RUNNERS.register_module()
class IterBasedRunnerAmp(IterBasedRunner):
    """Iteration-based Runner with AMP support.

    This runner trains models iteration by iteration.
    """
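    # Minimal usage sketch (assumptions flagged: ``model``, ``optimizer``,
    # ``train_loader`` and ``logger`` are built by the caller, and the
    # model/optimizer are wrapped with apex beforehand; names here are
    # illustrative, not part of this module):
    #
    #   model, optimizer = apex.amp.initialize(model, optimizer,
    #                                          opt_level='O1')
    #   runner = IterBasedRunnerAmp(model, optimizer=optimizer,
    #                               work_dir='./work_dir', logger=logger,
    #                               max_iters=10000)
    #   runner.run([train_loader], [('train', 1)])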
    def save_checkpoint(self,
                        out_dir,
                        filename_tmpl='iter_{}.pth',
                        meta=None,
                        save_optimizer=True,
                        create_symlink=False):
        """Save checkpoint to file.

        Args:
            out_dir (str): Directory to save checkpoint files.
            filename_tmpl (str, optional): Checkpoint file template.
                Defaults to 'iter_{}.pth'.
            meta (dict, optional): Metadata to be saved in checkpoint.
                Defaults to None.
            save_optimizer (bool, optional): Whether to save the optimizer.
                Defaults to True.
            create_symlink (bool, optional): Whether to create a symlink to
                the latest checkpoint file. Defaults to False.
        """
        if meta is None:
            meta = dict(iter=self.iter + 1, epoch=self.epoch + 1)
        elif isinstance(meta, dict):
            meta.update(iter=self.iter + 1, epoch=self.epoch + 1)
        else:
            raise TypeError(
                f'meta should be a dict or None, but got {type(meta)}')
        if self.meta is not None:
            meta.update(self.meta)

        filename = filename_tmpl.format(self.iter + 1)
        filepath = osp.join(out_dir, filename)
        optimizer = self.optimizer if save_optimizer else None
        save_checkpoint(self.model, filepath, optimizer=optimizer, meta=meta)
        # in some environments, `os.symlink` is not supported, so you may
        # need to set `create_symlink` to False
        # if create_symlink:
        #     dst_file = osp.join(out_dir, 'latest.pth')
        #     if platform.system() != 'Windows':
        #         mmcv.symlink(filename, dst_file)
        #     else:
        #         shutil.copy(filepath, dst_file)
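
    # Example call of the method above (a sketch; ``runner`` stands in for a
    # constructed IterBasedRunnerAmp instance):
    #
    #   runner.save_checkpoint('./work_dir', filename_tmpl='iter_{}.pth')
    #   # at self.iter == 999 this writes ./work_dir/iter_1000.pth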
    def resume(self,
               checkpoint,
               resume_optimizer=True,
               map_location='default'):
        """Resume from a checkpoint file.

        Restores the epoch/iter counters, the optimizer state and, if the
        checkpoint carries an 'amp' entry, the apex AMP state.
        """
        if map_location == 'default':
            if torch.cuda.is_available():
                device_id = torch.cuda.current_device()
                checkpoint = self.load_checkpoint(
                    checkpoint,
                    map_location=lambda storage, loc: storage.cuda(device_id))
            else:
                checkpoint = self.load_checkpoint(checkpoint)
        else:
            checkpoint = self.load_checkpoint(
                checkpoint, map_location=map_location)

        self._epoch = checkpoint['meta']['epoch']
        self._iter = checkpoint['meta']['iter']
        self._inner_iter = checkpoint['meta']['iter']
        if 'optimizer' in checkpoint and resume_optimizer:
            if isinstance(self.optimizer, Optimizer):
                self.optimizer.load_state_dict(checkpoint['optimizer'])
            elif isinstance(self.optimizer, dict):
                for k in self.optimizer.keys():
                    self.optimizer[k].load_state_dict(
                        checkpoint['optimizer'][k])
            else:
                raise TypeError(
                    'Optimizer should be dict or torch.optim.Optimizer '
                    f'but got {type(self.optimizer)}')

        # restore the apex AMP (loss scaler) state saved with the weights
        if 'amp' in checkpoint:
            apex.amp.load_state_dict(checkpoint['amp'])
            self.logger.info('load amp state dict')

        self.logger.info(f'resumed from epoch: {self.epoch}, iter {self.iter}')
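
# Note on the 'amp' key consumed in ``resume`` above: the companion
# ``save_checkpoint`` imported from ``.checkpoint`` is expected to store the
# AMP state under that key, roughly as in this sketch (illustrative only,
# not this repo's actual code):
#
#   checkpoint['amp'] = apex.amp.state_dict()
#   torch.save(checkpoint, filepath)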