# Copyright (c) OpenMMLab. All rights reserved.
import warnings

import mmcv
import numpy as np
import torch
import torch.distributed as dist
from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
from mmcv.runner import (DistSamplerSeedHook, EpochBasedRunner, OptimizerHook,
                         get_dist_info)
from mmcv.utils import digit_version

from mmpose.core import DistEvalHook, EvalHook, build_optimizers
from mmpose.core.distributed_wrapper import DistributedDataParallelWrapper
from mmpose.datasets import build_dataloader, build_dataset
from mmpose.utils import get_root_logger

try:
    from mmcv.runner import Fp16OptimizerHook
except ImportError:
    warnings.warn(
        'Fp16OptimizerHook from mmpose will be deprecated from '
        'v0.15.0. Please install mmcv>=1.1.4', DeprecationWarning)
    from mmpose.core import Fp16OptimizerHook


def init_random_seed(seed=None, device='cuda'):
    """Initialize random seed.

    If the seed is not set, it will be randomized automatically and then
    broadcast to all processes to prevent some potential bugs.

    Args:
        seed (int, optional): The seed. Defaults to None.
        device (str): The device where the seed will be put on.
            Defaults to 'cuda'.

    Returns:
        int: Seed to be used.
    """
    if seed is not None:
        return seed

    # Make sure all ranks share the same random seed to prevent
    # some potential bugs. Please refer to
    # https://github.com/open-mmlab/mmdetection/issues/6339
    rank, world_size = get_dist_info()
    seed = np.random.randint(2**31)
    if world_size == 1:
        return seed

    if rank == 0:
        random_num = torch.tensor(seed, dtype=torch.int32, device=device)
    else:
        random_num = torch.tensor(0, dtype=torch.int32, device=device)
    dist.broadcast(random_num, src=0)
    return random_num.item()
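
# Usage sketch (illustrative, not part of the original module): the seed is
# initialized once per run and is identical on every rank, so cross-rank
# randomness such as data shuffling stays in sync. `args.seed` is an assumed
# CLI argument of the hypothetical caller:
#
#     seed = init_random_seed(args.seed)
#     logger.info(f'Set random seed to {seed}')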


def train_model(model,
                dataset,
                cfg,
                distributed=False,
                validate=False,
                timestamp=None,
                meta=None):
    """Train model entry function.

    Args:
        model (nn.Module): The model to be trained.
        dataset (Dataset): Train dataset.
        cfg (dict): The config dict for training.
        distributed (bool): Whether to use distributed training.
            Default: False.
        validate (bool): Whether to do evaluation. Default: False.
        timestamp (str | None): Local time for runner. Default: None.
        meta (dict | None): Meta dict to record some important information.
            Default: None.
    """
    logger = get_root_logger(cfg.log_level)

    # prepare data loaders
    dataset = dataset if isinstance(dataset, (list, tuple)) else [dataset]
    # step 1: set default values and override them (if they exist) with
    # cfg.data
    loader_cfg = {
        **dict(
            seed=cfg.get('seed'),
            drop_last=False,
            dist=distributed,
            num_gpus=len(cfg.gpu_ids)),
        **({} if torch.__version__ != 'parrots' else dict(
            prefetch_num=2,
            pin_memory=False,
        )),
        **dict((k, cfg.data[k]) for k in [
            'samples_per_gpu',
            'workers_per_gpu',
            'shuffle',
            'seed',
            'drop_last',
            'prefetch_num',
            'pin_memory',
            'persistent_workers',
        ] if k in cfg.data)
    }

    # step 2: cfg.data.train_dataloader has the highest priority
    train_loader_cfg = dict(loader_cfg, **cfg.data.get('train_dataloader', {}))

    data_loaders = [build_dataloader(ds, **train_loader_cfg) for ds in dataset]
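
    # For reference, an assumed config snippet of the kind consumed above (not
    # part of this file): keys in `cfg.data.train_dataloader` take precedence
    # over both the defaults and the top-level `cfg.data` values, e.g.
    #
    #     data = dict(
    #         samples_per_gpu=64,
    #         workers_per_gpu=2,
    #         train_dataloader=dict(samples_per_gpu=32))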

    # determine whether to use adversarial training or not
    use_adversarial_train = cfg.get('use_adversarial_train', False)

    # put the model on GPUs
    if distributed:
        find_unused_parameters = cfg.get('find_unused_parameters', False)
        # Sets the `find_unused_parameters` parameter in
        # torch.nn.parallel.DistributedDataParallel

        if use_adversarial_train:
            # Use DistributedDataParallelWrapper for adversarial training
            model = DistributedDataParallelWrapper(
                model,
                device_ids=[torch.cuda.current_device()],
                broadcast_buffers=False,
                find_unused_parameters=find_unused_parameters)
        else:
            model = MMDistributedDataParallel(
                model.cuda(),
                device_ids=[torch.cuda.current_device()],
                broadcast_buffers=False,
                find_unused_parameters=find_unused_parameters)
    else:
        if digit_version(mmcv.__version__) >= digit_version(
                '1.4.4') or torch.cuda.is_available():
            model = MMDataParallel(model, device_ids=cfg.gpu_ids)
        else:
            warnings.warn(
                'We recommend using MMCV >= 1.4.4 for CPU training. '
                'See https://github.com/open-mmlab/mmpose/pull/1157 for '
                'details.')

    # build runner
    optimizer = build_optimizers(model, cfg.optimizer)

    runner = EpochBasedRunner(
        model,
        optimizer=optimizer,
        work_dir=cfg.work_dir,
        logger=logger,
        meta=meta)
    # an ugly workaround to make .log and .log.json filenames the same
    runner.timestamp = timestamp
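
    # Note: `build_optimizers` also accepts a dict of per-submodule optimizer
    # configs, which is how adversarial setups keep separate optimizers. An
    # assumed snippet (keys must match attribute names on the model):
    #
    #     optimizer = dict(
    #         generator=dict(type='Adam', lr=1e-4),
    #         discriminator=dict(type='Adam', lr=4e-4))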

    if use_adversarial_train:
        # The optimizer step is already performed inside the model's
        # train_step function, so the runner should NOT include an
        # optimizer hook.
        optimizer_config = None
    else:
        # fp16 setting
        fp16_cfg = cfg.get('fp16', None)
        if fp16_cfg is not None:
            optimizer_config = Fp16OptimizerHook(
                **cfg.optimizer_config, **fp16_cfg, distributed=distributed)
        elif distributed and 'type' not in cfg.optimizer_config:
            optimizer_config = OptimizerHook(**cfg.optimizer_config)
        else:
            optimizer_config = cfg.optimizer_config
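
    # A typical fp16 setting in an OpenMMLab-style config (illustrative
    # values, an assumption rather than a fixed requirement):
    #
    #     fp16 = dict(loss_scale=512.)
    #     optimizer_config = dict(grad_clip=dict(max_norm=1., norm_type=2))
    #
    # When `fp16` is present, Fp16OptimizerHook handles loss scaling and
    # casting; otherwise a plain OptimizerHook (or the raw config) is used.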

    # register hooks
    runner.register_training_hooks(cfg.lr_config, optimizer_config,
                                   cfg.checkpoint_config, cfg.log_config,
                                   cfg.get('momentum_config', None))
    if distributed:
        runner.register_hook(DistSamplerSeedHook())

    # register eval hooks
    if validate:
        eval_cfg = cfg.get('evaluation', {})
        val_dataset = build_dataset(cfg.data.val, dict(test_mode=True))
        dataloader_setting = dict(
            samples_per_gpu=1,
            workers_per_gpu=cfg.data.get('workers_per_gpu', 1),
            # cfg.gpus will be ignored if distributed
            num_gpus=len(cfg.gpu_ids),
            dist=distributed,
            drop_last=False,
            shuffle=False)
        dataloader_setting = dict(dataloader_setting,
                                  **cfg.data.get('val_dataloader', {}))
        val_dataloader = build_dataloader(val_dataset, **dataloader_setting)
        eval_hook = DistEvalHook if distributed else EvalHook
        runner.register_hook(eval_hook(val_dataloader, **eval_cfg))
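
        # The `evaluation` dict above is forwarded to the eval hook as keyword
        # arguments. An assumed snippet in the style of mmpose configs:
        #
        #     evaluation = dict(interval=10, metric='mAP', save_best='AP')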

    if cfg.resume_from:
        runner.resume(cfg.resume_from)
    elif cfg.load_from:
        runner.load_checkpoint(cfg.load_from)
    runner.run(data_loaders, cfg.workflow, cfg.total_epochs)
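

# End-to-end usage sketch (illustrative; `args`, `distributed`, `timestamp`
# and `meta` are assumptions about the caller, typically a tools/train.py
# entry script):
#
#     import mmcv
#     from mmpose.models import build_posenet
#
#     cfg = mmcv.Config.fromfile(args.config)
#     model = build_posenet(cfg.model)
#     datasets = [build_dataset(cfg.data.train)]
#     train_model(model, datasets, cfg, distributed=distributed,
#                 validate=not args.no_validate, timestamp=timestamp,
#                 meta=meta)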