# Copyright (c) Open-MMLab. All rights reserved.
import copy
import io
import math
import os
import os.path as osp
import pkgutil
import re
import time
import warnings
from collections import OrderedDict
from importlib import import_module
from tempfile import TemporaryDirectory

import numpy as np
import torch
import torchvision
from scipy import interpolate
from torch.nn import functional as F
from torch.optim import Optimizer
from torch.utils import model_zoo

import mmcv
from mmcv.fileio import FileClient
from mmcv.fileio import load as load_file
from mmcv.parallel import is_module_wrapper
from mmcv.runner import get_dist_info
from mmcv.utils import mkdir_or_exist

ENV_MMCV_HOME = 'MMCV_HOME'
ENV_XDG_CACHE_HOME = 'XDG_CACHE_HOME'
DEFAULT_CACHE_DIR = '~/.cache'


def _get_mmcv_home():
    """Return (and create if needed) the local mmcv cache directory.

    Resolution order: ``$MMCV_HOME``, then ``$XDG_CACHE_HOME/mmcv``,
    falling back to ``~/.cache/mmcv``.
    """
    mmcv_home = os.path.expanduser(
        os.getenv(
            ENV_MMCV_HOME,
            os.path.join(
                os.getenv(ENV_XDG_CACHE_HOME, DEFAULT_CACHE_DIR), 'mmcv')))

    mkdir_or_exist(mmcv_home)
    return mmcv_home
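

# Usage sketch (illustrative environment value, not part of this file):
#
#     os.environ['MMCV_HOME'] = '/tmp/mmcv_home'
#     assert _get_mmcv_home() == '/tmp/mmcv_home'  # directory is created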


def load_state_dict(module, state_dict, strict=False, logger=None):
    """Load state_dict into a module.

    This method is modified from :meth:`torch.nn.Module.load_state_dict`.
    The default value of ``strict`` is ``False``, and mismatched params
    are reported even when ``strict`` is ``False``.

    Args:
        module (Module): Module that receives the state_dict.
        state_dict (OrderedDict): Weights.
        strict (bool): Whether to strictly enforce that the keys
            in :attr:`state_dict` match the keys returned by this module's
            :meth:`~torch.nn.Module.state_dict` function. Default: ``False``.
        logger (:obj:`logging.Logger`, optional): Logger to log the error
            message. If not specified, the print function is used.
    """
    unexpected_keys = []
    all_missing_keys = []
    err_msg = []

    metadata = getattr(state_dict, '_metadata', None)
    state_dict = state_dict.copy()
    if metadata is not None:
        state_dict._metadata = metadata

    # use _load_from_state_dict to enable checkpoint version control
    def load(module, prefix=''):
        # recursively check parallel module in case that the model has a
        # complicated structure, e.g., nn.Module(nn.Module(DDP))
        if is_module_wrapper(module):
            module = module.module
        local_metadata = {} if metadata is None else metadata.get(
            prefix[:-1], {})
        module._load_from_state_dict(state_dict, prefix, local_metadata, True,
                                     all_missing_keys, unexpected_keys,
                                     err_msg)
        for name, child in module._modules.items():
            if child is not None:
                load(child, prefix + name + '.')

    load(module)
    load = None  # break load->load reference cycle

    # ignore "num_batches_tracked" of BN layers
    missing_keys = [
        key for key in all_missing_keys if 'num_batches_tracked' not in key
    ]

    if unexpected_keys:
        err_msg.append('unexpected key in source '
                       f'state_dict: {", ".join(unexpected_keys)}\n')
    if missing_keys:
        err_msg.append(
            f'missing keys in source state_dict: {", ".join(missing_keys)}\n')

    rank, _ = get_dist_info()
    if len(err_msg) > 0 and rank == 0:
        err_msg.insert(
            0, 'The model and loaded state dict do not match exactly\n')
        err_msg = '\n'.join(err_msg)
        if strict:
            raise RuntimeError(err_msg)
        elif logger is not None:
            logger.warning(err_msg)
        else:
            print(err_msg)
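

# Usage sketch for ``load_state_dict`` (the toy module and the injected key
# are illustrative assumptions, not part of this file):
#
#     model = torch.nn.Linear(4, 2)
#     sd = model.state_dict()
#     sd['extra.weight'] = torch.zeros(1)      # inject an unexpected key
#     load_state_dict(model, sd)               # strict=False: warns, loads
#     load_state_dict(model, sd, strict=True)  # raises RuntimeError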


def load_url_dist(url, model_dir=None, map_location="cpu"):
    """In distributed settings, local rank 0 downloads the checkpoint first;
    the remaining ranks wait at a barrier and then read it from the cache."""
    rank, world_size = get_dist_info()
    rank = int(os.environ.get('LOCAL_RANK', rank))
    if rank == 0:
        checkpoint = model_zoo.load_url(url, model_dir=model_dir, map_location=map_location)
    if world_size > 1:
        torch.distributed.barrier()
        if rank > 0:
            checkpoint = model_zoo.load_url(url, model_dir=model_dir, map_location=map_location)
    return checkpoint
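

# Rank-0-first download pattern (illustrative URL, not part of this file):
#
#     ckpt = load_url_dist(
#         'https://download.pytorch.org/models/resnet50-0676ba61.pth')
#
# On a single process this behaves like a plain ``model_zoo.load_url``;
# under ``torch.distributed`` only local rank 0 downloads, the rest hit
# the on-disk cache after the barrier.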


def load_pavimodel_dist(model_path, map_location=None):
    """In distributed settings, local rank 0 downloads the checkpoint from
    pavi modelcloud first; the remaining ranks download after a barrier."""
    try:
        from pavi import modelcloud
    except ImportError:
        raise ImportError(
            'Please install pavi to load checkpoint from modelcloud.')
    rank, world_size = get_dist_info()
    rank = int(os.environ.get('LOCAL_RANK', rank))
    if rank == 0:
        model = modelcloud.get(model_path)
        with TemporaryDirectory() as tmp_dir:
            downloaded_file = osp.join(tmp_dir, model.name)
            model.download(downloaded_file)
            checkpoint = torch.load(downloaded_file, map_location=map_location)
    if world_size > 1:
        torch.distributed.barrier()
        if rank > 0:
            model = modelcloud.get(model_path)
            with TemporaryDirectory() as tmp_dir:
                downloaded_file = osp.join(tmp_dir, model.name)
                model.download(downloaded_file)
                checkpoint = torch.load(
                    downloaded_file, map_location=map_location)
    return checkpoint


def load_fileclient_dist(filename, backend, map_location):
    """In distributed settings, local rank 0 fetches the checkpoint through
    the file client first; the remaining ranks fetch after a barrier."""
    rank, world_size = get_dist_info()
    rank = int(os.environ.get('LOCAL_RANK', rank))
    allowed_backends = ['ceph']
    if backend not in allowed_backends:
        raise ValueError(f'Load from Backend {backend} is not supported.')
    if rank == 0:
        fileclient = FileClient(backend=backend)
        buffer = io.BytesIO(fileclient.get(filename))
        checkpoint = torch.load(buffer, map_location=map_location)
    if world_size > 1:
        torch.distributed.barrier()
        if rank > 0:
            fileclient = FileClient(backend=backend)
            buffer = io.BytesIO(fileclient.get(filename))
            checkpoint = torch.load(buffer, map_location=map_location)
    return checkpoint


def get_torchvision_models():
    """Collect the ``model_urls`` dicts of all ``torchvision.models``
    submodules into a single name-to-URL mapping."""
    model_urls = dict()
    for _, name, ispkg in pkgutil.walk_packages(torchvision.models.__path__):
        if ispkg:
            continue
        _zoo = import_module(f'torchvision.models.{name}')
        if hasattr(_zoo, 'model_urls'):
            _urls = getattr(_zoo, 'model_urls')
            model_urls.update(_urls)
    return model_urls
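

# The result looks like ``{'resnet50': 'https://download.pytorch.org/...'}``.
# Note that recent torchvision releases (the multi-weight API introduced in
# v0.13) deprecate and eventually drop the module-level ``model_urls`` dicts,
# in which case this helper returns an empty mapping and the
# ``torchvision://`` scheme below will fail with a KeyError.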


def get_external_models():
    """Load the default open-mmlab model-zoo URLs shipped with mmcv and
    merge in any user overrides from ``$MMCV_HOME/open_mmlab.json``."""
    mmcv_home = _get_mmcv_home()
    default_json_path = osp.join(mmcv.__path__[0], 'model_zoo/open_mmlab.json')
    default_urls = load_file(default_json_path)
    assert isinstance(default_urls, dict)
    external_json_path = osp.join(mmcv_home, 'open_mmlab.json')
    if osp.exists(external_json_path):
        external_urls = load_file(external_json_path)
        assert isinstance(external_urls, dict)
        default_urls.update(external_urls)

    return default_urls
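

# To override or extend the zoo, drop a JSON file at
# ``$MMCV_HOME/open_mmlab.json`` (illustrative entries, not shipped by mmcv):
#
#     {"my_resnet": "https://example.com/my_resnet.pth",
#      "my_local_model": "checkpoints/my_local_model.pth"}
#
# Non-URL values are resolved relative to the mmcv home directory by
# ``_load_checkpoint`` below.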


def get_mmcls_models():
    """Load the mmcls model-zoo URLs shipped with mmcv."""
    mmcls_json_path = osp.join(mmcv.__path__[0], 'model_zoo/mmcls.json')
    mmcls_urls = load_file(mmcls_json_path)

    return mmcls_urls


def get_deprecated_model_names():
    """Load the mapping from deprecated open-mmlab model names to their
    replacements."""
    deprecate_json_path = osp.join(mmcv.__path__[0],
                                   'model_zoo/deprecated.json')
    deprecate_urls = load_file(deprecate_json_path)
    assert isinstance(deprecate_urls, dict)

    return deprecate_urls


def _process_mmcls_checkpoint(checkpoint):
    """Strip the ``backbone.`` prefix from mmcls checkpoints so the weights
    can be loaded directly into a backbone module."""
    state_dict = checkpoint['state_dict']
    new_state_dict = OrderedDict()
    for k, v in state_dict.items():
        if k.startswith('backbone.'):
            new_state_dict[k[9:]] = v  # drop the 9-char 'backbone.' prefix
    new_checkpoint = dict(state_dict=new_state_dict)

    return new_checkpoint
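

# For example, 'backbone.layer1.0.conv1.weight' becomes
# 'layer1.0.conv1.weight'; keys outside the backbone (e.g. a classification
# head's 'head.fc.weight') are dropped.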


def _load_checkpoint(filename, map_location=None):
    """Load checkpoint from somewhere (modelzoo, file, url).

    Args:
        filename (str): Accepts a local filepath, URL, ``torchvision://xxx``
            or ``open-mmlab://xxx``. Please refer to ``docs/model_zoo.md``
            for details.
        map_location (str | None): Same as :func:`torch.load`. Default: None.

    Returns:
        dict | OrderedDict: The loaded checkpoint. It can be either an
            OrderedDict storing model weights or a dict containing other
            information, which depends on the checkpoint.
    """
    if filename.startswith('modelzoo://'):
        warnings.warn('The URL scheme of "modelzoo://" is deprecated, please '
                      'use "torchvision://" instead')
        model_urls = get_torchvision_models()
        model_name = filename[11:]
        checkpoint = load_url_dist(model_urls[model_name])
    elif filename.startswith('torchvision://'):
        model_urls = get_torchvision_models()
        model_name = filename[14:]
        checkpoint = load_url_dist(model_urls[model_name])
    elif filename.startswith('open-mmlab://'):
        model_urls = get_external_models()
        model_name = filename[13:]
        deprecated_urls = get_deprecated_model_names()
        if model_name in deprecated_urls:
            warnings.warn(f'open-mmlab://{model_name} is deprecated in favor '
                          f'of open-mmlab://{deprecated_urls[model_name]}')
            model_name = deprecated_urls[model_name]
        model_url = model_urls[model_name]
        # check whether the value is a URL or a path relative to mmcv home
        if model_url.startswith(('http://', 'https://')):
            checkpoint = load_url_dist(model_url)
        else:
            filename = osp.join(_get_mmcv_home(), model_url)
            if not osp.isfile(filename):
                raise IOError(f'{filename} is not a checkpoint file')
            checkpoint = torch.load(filename, map_location=map_location)
    elif filename.startswith('mmcls://'):
        model_urls = get_mmcls_models()
        model_name = filename[8:]
        checkpoint = load_url_dist(model_urls[model_name])
        checkpoint = _process_mmcls_checkpoint(checkpoint)
    elif filename.startswith(('http://', 'https://')):
        checkpoint = load_url_dist(filename)
    elif filename.startswith('pavi://'):
        model_path = filename[7:]
        checkpoint = load_pavimodel_dist(model_path, map_location=map_location)
    elif filename.startswith('s3://'):
        checkpoint = load_fileclient_dist(
            filename, backend='ceph', map_location=map_location)
    else:
        if not osp.isfile(filename):
            raise IOError(f'{filename} is not a checkpoint file')
        checkpoint = torch.load(filename, map_location=map_location)
    return checkpoint
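

# Supported URI schemes at a glance (illustrative names, not part of mmcv):
#
#     _load_checkpoint('torchvision://resnet50')        # torchvision zoo
#     _load_checkpoint('open-mmlab://resnet50_caffe')   # open-mmlab zoo
#     _load_checkpoint('mmcls://resnet50')              # mmcls zoo
#     _load_checkpoint('https://example.com/ckpt.pth')  # direct URL
#     _load_checkpoint('checkpoints/latest.pth')        # local file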


def cosine_scheduler(base_value, final_value, epochs, niter_per_ep, warmup_epochs=0,
                     start_warmup_value=0, warmup_steps=-1):
    """Build a per-iteration cosine decay schedule from ``base_value`` to
    ``final_value``, with an optional linear warmup from
    ``start_warmup_value``. ``warmup_steps > 0`` overrides ``warmup_epochs``.
    Returns an array of length ``epochs * niter_per_ep``."""
    warmup_schedule = np.array([])
    warmup_iters = warmup_epochs * niter_per_ep
    if warmup_steps > 0:
        warmup_iters = warmup_steps
    print("Set warmup steps = %d" % warmup_iters)
    if warmup_epochs > 0:
        warmup_schedule = np.linspace(start_warmup_value, base_value, warmup_iters)

    iters = np.arange(epochs * niter_per_ep - warmup_iters)
    schedule = np.array(
        [final_value + 0.5 * (base_value - final_value) * (1 + math.cos(math.pi * i / (len(iters)))) for i in iters])

    schedule = np.concatenate((warmup_schedule, schedule))

    assert len(schedule) == epochs * niter_per_ep
    return schedule
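

# For example (illustrative values), cosine_scheduler(1e-3, 1e-5, epochs=2,
# niter_per_ep=3, warmup_epochs=1) yields 6 values: a 3-step linear ramp
# from 0 to 1e-3, then a 3-step cosine decay towards 1e-5. A common use is
# one value per optimizer step over the whole run.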


def load_checkpoint(model,
                    filename,
                    map_location='cpu',
                    strict=False,
                    logger=None,
                    patch_padding='pad',
                    part_features=None
                    ):
    """Load checkpoint from a file or URI.

    Args:
        model (Module): Module to load checkpoint.
        filename (str): Accepts a local filepath, URL, ``torchvision://xxx``
            or ``open-mmlab://xxx``. Please refer to ``docs/model_zoo.md``
            for details.
        map_location (str): Same as :func:`torch.load`.
        strict (bool): Whether to raise an error when the model and the
            checkpoint have mismatched params. Default: False.
        logger (:mod:`logging.Logger` or None): The logger for error messages.
        patch_padding (str): 'pad', 'bilinear' or 'bicubic'; how to resize
            the patch embedding kernel, e.g. from 14x14 to 16x16.
        part_features (int or None): If set, the checkpoint's ``fc2`` weights
            are split: each ``mlp.experts`` branch receives the last
            ``part_features`` output channels and ``fc2`` keeps the rest.

    Returns:
        dict or OrderedDict: The loaded checkpoint.
    """
    checkpoint = _load_checkpoint(filename, map_location)
    # OrderedDict is a subclass of dict
    if not isinstance(checkpoint, dict):
        raise RuntimeError(
            f'No state_dict found in checkpoint file {filename}')
    # get state_dict from checkpoint
    if 'state_dict' in checkpoint:
        state_dict = checkpoint['state_dict']
    elif 'model' in checkpoint:
        state_dict = checkpoint['model']
    elif 'module' in checkpoint:
        state_dict = checkpoint['module']
    else:
        state_dict = checkpoint
    # strip prefix of state_dict
    if list(state_dict.keys())[0].startswith('module.'):
        state_dict = {k[7:]: v for k, v in state_dict.items()}

    # for MoBY, load model of online branch
    if sorted(list(state_dict.keys()))[0].startswith('encoder'):
        state_dict = {k.replace('encoder.', ''): v for k, v in state_dict.items() if k.startswith('encoder.')}

    rank, _ = get_dist_info()

    # resize the patch embedding kernel if the checkpoint and the model
    # use different patch sizes
    if 'patch_embed.proj.weight' in state_dict:
        proj_weight = state_dict['patch_embed.proj.weight']
        orig_size = proj_weight.shape[2:]
        current_size = model.patch_embed.proj.weight.shape[2:]
        padding_size = current_size[0] - orig_size[0]
        padding_l = padding_size // 2
        padding_r = padding_size - padding_l
        if orig_size != current_size:
            if 'pad' in patch_padding:
                proj_weight = torch.nn.functional.pad(proj_weight, (padding_l, padding_r, padding_l, padding_r))
            elif 'bilinear' in patch_padding:
                proj_weight = torch.nn.functional.interpolate(proj_weight, size=current_size, mode='bilinear', align_corners=False)
            elif 'bicubic' in patch_padding:
                proj_weight = torch.nn.functional.interpolate(proj_weight, size=current_size, mode='bicubic', align_corners=False)
            state_dict['patch_embed.proj.weight'] = proj_weight

    # interpolate the position embedding if the checkpoint and the model
    # use different input grid shapes
    if 'pos_embed' in state_dict:
        pos_embed_checkpoint = state_dict['pos_embed']
        embedding_size = pos_embed_checkpoint.shape[-1]
        H, W = model.patch_embed.patch_shape
        num_patches = model.patch_embed.num_patches
        num_extra_tokens = model.pos_embed.shape[-2] - num_patches
        # height (== width) for the checkpoint position embedding
        orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
        if rank == 0:
            print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, H, W))
        extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
        # only the position tokens are interpolated
        pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
        pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
        pos_tokens = torch.nn.functional.interpolate(
            pos_tokens, size=(H, W), mode='bicubic', align_corners=False)
        pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
        new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
        state_dict['pos_embed'] = new_pos_embed

    new_state_dict = copy.deepcopy(state_dict)
    if part_features is not None:
        current_keys = list(model.state_dict().keys())
        for key in current_keys:
            if "mlp.experts" in key:
                # map each expert back to the shared fc2 weights, keeping
                # only the last ``part_features`` channels
                source_key = re.sub(r'experts\.\d+\.', 'fc2.', key)
                new_state_dict[key] = state_dict[source_key][-part_features:]
            elif 'fc2' in key:
                new_state_dict[key] = state_dict[key][:-part_features]

    # load state_dict
    load_state_dict(model, new_state_dict, strict, logger)
    return checkpoint
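

# Typical use with a ViT-style backbone (illustrative path and names, not
# part of this file):
#
#     ckpt = load_checkpoint(model, 'pretrained/mae_pretrain_vit_base.pth',
#                            map_location='cpu', strict=False,
#                            patch_padding='pad')
#
# The patch-embed resizing and pos-embed interpolation above only trigger
# when the checkpoint and the model disagree on kernel size or grid shape.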


def weights_to_cpu(state_dict):
    """Copy a model state_dict to cpu.

    Args:
        state_dict (OrderedDict): Model weights on GPU.

    Returns:
        OrderedDict: Model weights on CPU.
    """
    state_dict_cpu = OrderedDict()
    for key, val in state_dict.items():
        state_dict_cpu[key] = val.cpu()
    return state_dict_cpu


def _save_to_state_dict(module, destination, prefix, keep_vars):
    """Saves module state to ``destination`` dictionary.

    This method is modified from :meth:`torch.nn.Module._save_to_state_dict`.

    Args:
        module (nn.Module): The module to generate state_dict.
        destination (dict): A dict where state will be stored.
        prefix (str): The prefix for parameters and buffers used in this
            module.
        keep_vars (bool): Whether to keep the autograd graph of the
            parameters; if False, detached tensors are stored.
    """
    for name, param in module._parameters.items():
        if param is not None:
            destination[prefix + name] = param if keep_vars else param.detach()
    for name, buf in module._buffers.items():
        # remove check of _non_persistent_buffers_set to allow nn.BatchNorm2d
        if buf is not None:
            destination[prefix + name] = buf if keep_vars else buf.detach()


def get_state_dict(module, destination=None, prefix='', keep_vars=False):
    """Returns a dictionary containing a whole state of the module.

    Both parameters and persistent buffers (e.g. running averages) are
    included. Keys are corresponding parameter and buffer names.

    This method is modified from :meth:`torch.nn.Module.state_dict` to
    recursively check parallel module in case that the model has a complicated
    structure, e.g., nn.Module(nn.Module(DDP)).

    Args:
        module (nn.Module): The module to generate state_dict.
        destination (OrderedDict): Returned dict for the state of the
            module.
        prefix (str): Prefix of the key.
        keep_vars (bool): Whether to keep the variable property of the
            parameters. Default: False.

    Returns:
        dict: A dictionary containing a whole state of the module.
    """
    # recursively check parallel module in case that the model has a
    # complicated structure, e.g., nn.Module(nn.Module(DDP))
    if is_module_wrapper(module):
        module = module.module

    # below is the same as torch.nn.Module.state_dict()
    if destination is None:
        destination = OrderedDict()
        destination._metadata = OrderedDict()
    destination._metadata[prefix[:-1]] = local_metadata = dict(
        version=module._version)
    _save_to_state_dict(module, destination, prefix, keep_vars)
    for name, child in module._modules.items():
        if child is not None:
            get_state_dict(
                child, destination, prefix + name + '.', keep_vars=keep_vars)
    for hook in module._state_dict_hooks.values():
        hook_result = hook(module, destination, prefix, local_metadata)
        if hook_result is not None:
            destination = hook_result
    return destination


def save_checkpoint(model, filename, optimizer=None, meta=None):
    """Save checkpoint to file.

    The checkpoint will have 3 fields: ``meta``, ``state_dict`` and
    ``optimizer``. By default ``meta`` will contain version and time info.

    Args:
        model (Module): Module whose params are to be saved.
        filename (str): Checkpoint filename.
        optimizer (:obj:`Optimizer`, optional): Optimizer to be saved.
        meta (dict, optional): Metadata to be saved in checkpoint.
    """
    if meta is None:
        meta = {}
    elif not isinstance(meta, dict):
        raise TypeError(f'meta must be a dict or None, but got {type(meta)}')
    meta.update(mmcv_version=mmcv.__version__, time=time.asctime())

    if is_module_wrapper(model):
        model = model.module

    if hasattr(model, 'CLASSES') and model.CLASSES is not None:
        # save class name to the meta
        meta.update(CLASSES=model.CLASSES)

    checkpoint = {
        'meta': meta,
        'state_dict': weights_to_cpu(get_state_dict(model))
    }
    # save optimizer state dict in the checkpoint
    if isinstance(optimizer, Optimizer):
        checkpoint['optimizer'] = optimizer.state_dict()
    elif isinstance(optimizer, dict):
        checkpoint['optimizer'] = {}
        for name, optim in optimizer.items():
            checkpoint['optimizer'][name] = optim.state_dict()

    if filename.startswith('pavi://'):
        try:
            from pavi import modelcloud
            from pavi.exception import NodeNotFoundError
        except ImportError:
            raise ImportError(
                'Please install pavi to load checkpoint from modelcloud.')
        model_path = filename[7:]
        root = modelcloud.Folder()
        model_dir, model_name = osp.split(model_path)
        try:
            model = modelcloud.get(model_dir)
        except NodeNotFoundError:
            model = root.create_training_model(model_dir)
        with TemporaryDirectory() as tmp_dir:
            checkpoint_file = osp.join(tmp_dir, model_name)
            with open(checkpoint_file, 'wb') as f:
                torch.save(checkpoint, f)
                f.flush()
            model.create_file(checkpoint_file, name=model_name)
    else:
        mmcv.mkdir_or_exist(osp.dirname(filename))
        # immediately flush buffer
        with open(filename, 'wb') as f:
            torch.save(checkpoint, f)
            f.flush()
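

# Round-trip sketch (illustrative paths, not part of this file):
#
#     save_checkpoint(model, 'work_dirs/epoch_1.pth', optimizer=optimizer,
#                     meta=dict(epoch=1))
#     load_checkpoint(model, 'work_dirs/epoch_1.pth', map_location='cpu')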