18 changed files with 3317 additions and 2 deletions
@@ -0,0 +1,177 @@ |
|||
_base_ = [ |
|||
'../../../../_base_/default_runtime.py', |
|||
'../../../../_base_/datasets/coco.py' |
|||
] |
|||
evaluation = dict(interval=1, metric='mAP', save_best='AP') |
|||
|
|||
optimizer = dict(type='AdamW', |
|||
lr=5e-4, betas=(0.9, 0.999), weight_decay=0.1, |
|||
constructor='LayerDecayOptimizerConstructor', |
|||
paramwise_cfg=dict( |
|||
num_layers=12, |
|||
layer_decay_rate=0.75, |
|||
custom_keys={ |
|||
'bias': dict(decay_mult=0.), |
|||
'pos_embed': dict(decay_mult=0.), |
|||
'relative_position_bias_table': dict(decay_mult=0.), |
|||
'norm': dict(decay_mult=0.) |
|||
} |
|||
) |
|||
) |
|||
|
|||
optimizer_config = dict(grad_clip=dict(max_norm=1., norm_type=2)) |
|||
|
|||
# learning policy |
|||
lr_config = dict( |
|||
policy='step', |
|||
warmup='linear', |
|||
warmup_iters=500, |
|||
warmup_ratio=0.001, |
|||
step=[170, 200]) |
|||
total_epochs = 210 |
|||
target_type = 'GaussianHeatmap' |
|||
channel_cfg = dict( |
|||
num_output_channels=17, |
|||
dataset_joints=17, |
|||
dataset_channel=[ |
|||
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16], |
|||
], |
|||
inference_channel=[ |
|||
0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16 |
|||
]) |
|||
|
|||
# model settings |
|||
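# TopDownSelf couples the trainable ViTSam backbone (ViTPose ViT-B fused with a |
|||
# frozen SAM ViT-B encoder via cross-attention) with a standard heatmap head. |
|||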
model = dict( |
|||
type='TopDownSelf', |
|||
pretrained=None, |
|||
backbone=dict( |
|||
type='ViTSam', |
|||
img_size=(256, 192), |
|||
patch_size=16, |
|||
embed_dim=768, |
|||
depth=12, |
|||
num_heads=12, |
|||
ratio=1, |
|||
use_checkpoint=False, |
|||
mlp_ratio=4, |
|||
qkv_bias=True, |
|||
drop_path_rate=0.3, |
|||
frozen_stages=12, |
|||
freeze_attn=True, |
|||
freeze_ffn=True, |
|||
samvit_checkpoint='/root/autodl-tmp/code/ViTPose/checkpoints/sam/sam_vit_b_01ec64.pth' |
|||
), |
|||
keypoint_head=dict( |
|||
type='TopdownHeatmapSimpleHead', |
|||
in_channels=768, |
|||
num_deconv_layers=2, |
|||
num_deconv_filters=(256, 256), |
|||
num_deconv_kernels=(4, 4), |
|||
extra=dict(final_conv_kernel=1, ), |
|||
out_channels=channel_cfg['num_output_channels'], |
|||
loss_keypoint=dict(type='JointsMSELoss', use_target_weight=True)), |
|||
train_cfg=dict(), |
|||
test_cfg=dict( |
|||
flip_test=True, |
|||
post_process='default', |
|||
shift_heatmap=False, |
|||
target_type=target_type, |
|||
modulate_kernel=11, |
|||
use_udp=True)) |
|||
|
|||
data_root = '/root/autodl-tmp/dataset/coco2017/' |
|||
|
|||
data_cfg = dict( |
|||
image_size=[192, 256], |
|||
heatmap_size=[48, 64], |
|||
num_output_channels=channel_cfg['num_output_channels'], |
|||
num_joints=channel_cfg['dataset_joints'], |
|||
dataset_channel=channel_cfg['dataset_channel'], |
|||
inference_channel=channel_cfg['inference_channel'], |
|||
soft_nms=False, |
|||
nms_thr=1.0, |
|||
oks_thr=0.9, |
|||
vis_thr=0.2, |
|||
use_gt_bbox=False, |
|||
det_bbox_thr=0.0, |
|||
bbox_file=f'{data_root}/person_detection_results/COCO_val2017_detections_AP_H_56_person.json', |
|||
) |
|||
|
|||
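# The *Sam pipeline steps below additionally produce a 1024x1024 'sam_img' view of |
|||
# each sample as input for the frozen SAM encoder. |
|||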
train_pipeline = [ |
|||
dict(type='LoadImageFromFile'), |
|||
dict(type='TopDownRandomFlip', flip_prob=0.5), |
|||
dict( |
|||
type='TopDownHalfBodyTransform', |
|||
num_joints_half_body=8, |
|||
prob_half_body=0.3), |
|||
dict( |
|||
type='TopDownGetRandomScaleRotation', rot_factor=40, scale_factor=0.5), |
|||
# dict(type='TopDownAffine', use_udp=True), |
|||
dict(type='TopDownAffineSam', use_udp=True), |
|||
dict(type='ToTensorSam'), |
|||
dict( |
|||
type='NormalizeTensorSam', |
|||
mean=[0.485, 0.456, 0.406], |
|||
std=[0.229, 0.224, 0.225]), |
|||
dict( |
|||
type='TopDownGenerateTarget', |
|||
sigma=2, |
|||
encoding='UDP', |
|||
target_type=target_type), |
|||
dict( |
|||
type='Collect', |
|||
keys=['img', 'sam_img', 'target', 'target_weight'], |
|||
meta_keys=[ |
|||
'image_file', 'joints_3d', 'joints_3d_visible', 'center', 'scale', |
|||
'rotation', 'bbox_score', 'flip_pairs' |
|||
]), |
|||
] |
|||
|
|||
val_pipeline = [ |
|||
dict(type='LoadImageFromFile'), |
|||
# dict(type='TopDownAffine', use_udp=True), |
|||
dict(type='TopDownAffineSam', use_udp=True), |
|||
dict(type='ToTensorSam'), |
|||
dict( |
|||
type='NormalizeTensorSam', |
|||
mean=[0.485, 0.456, 0.406], |
|||
std=[0.229, 0.224, 0.225]), |
|||
dict( |
|||
type='Collect', |
|||
keys=['img', 'sam_img'], |
|||
meta_keys=[ |
|||
'image_file', 'center', 'scale', 'rotation', 'bbox_score', |
|||
'flip_pairs' |
|||
]), |
|||
] |
|||
|
|||
test_pipeline = val_pipeline |
|||
|
|||
data = dict( |
|||
samples_per_gpu=12, |
|||
workers_per_gpu=4, |
|||
val_dataloader=dict(samples_per_gpu=12), |
|||
test_dataloader=dict(samples_per_gpu=12), |
|||
train=dict( |
|||
type='TopDownCocoDataset', |
|||
ann_file=f'{data_root}/annotations/person_keypoints_train2017.json', |
|||
img_prefix=f'{data_root}/train2017/', |
|||
data_cfg=data_cfg, |
|||
pipeline=train_pipeline, |
|||
dataset_info={{_base_.dataset_info}}), |
|||
val=dict( |
|||
type='TopDownCocoDataset', |
|||
ann_file=f'{data_root}/annotations/person_keypoints_val2017.json', |
|||
img_prefix=f'{data_root}/val2017/', |
|||
data_cfg=data_cfg, |
|||
pipeline=val_pipeline, |
|||
dataset_info={{_base_.dataset_info}}), |
|||
test=dict( |
|||
type='TopDownCocoDataset', |
|||
ann_file=f'{data_root}/annotations/person_keypoints_val2017.json', |
|||
img_prefix=f'{data_root}/val2017/', |
|||
data_cfg=data_cfg, |
|||
pipeline=test_pipeline, |
|||
dataset_info={{_base_.dataset_info}}), |
|||
) |
|||
|
@@ -0,0 +1 @@ |
|||
../../configs |
@@ -0,0 +1 @@ |
|||
../../demo |
@@ -0,0 +1 @@ |
|||
../../model-index.yml |
@@ -0,0 +1 @@ |
|||
../../tools |
@@ -0,0 +1,76 @@ |
|||
# Copyright (c) OpenMMLab. All rights reserved. |
|||
import warnings |
|||
from collections.abc import Sequence |
|||
|
|||
import mmcv |
|||
import numpy as np |
|||
from mmcv.parallel import DataContainer as DC |
|||
from mmcv.utils import build_from_cfg |
|||
from numpy import random |
|||
from torchvision.transforms import functional as F |
|||
|
|||
from ..builder import PIPELINES |
|||
|
|||
try: |
|||
import albumentations |
|||
except ImportError: |
|||
albumentations = None |
|||
|
|||
|
|||
@PIPELINES.register_module() |
|||
class ToTensorSam: |
|||
"""Transform image to Tensor. |
|||
|
|||
Required keys: 'img' and 'sam_img'. Modifies keys: 'img' and 'sam_img'. |
|||
|
|||
Args: |
|||
results (dict): contains all information about training. |
|||
""" |
|||
|
|||
def __call__(self, results): |
|||
if isinstance(results['img'], (list, tuple)): |
|||
results['img'] = [F.to_tensor(img) for img in results['img']] |
|||
# modified: also convert sam_img to a tensor |
|||
results['sam_img'] = [F.to_tensor(sam_img) for sam_img in results['sam_img']] |
|||
else: |
|||
results['img'] = F.to_tensor(results['img']) |
|||
# modified: also convert sam_img to a tensor |
|||
results['sam_img'] = F.to_tensor(results['sam_img']) |
|||
|
|||
return results |
|||
|
|||
|
|||
@PIPELINES.register_module() |
|||
class NormalizeTensorSam: |
|||
"""Normalize the Tensor image (CxHxW), with mean and std. |
|||
|
|||
Required keys: 'img' and 'sam_img'. Modifies keys: 'img' and 'sam_img'. |
|||
|
|||
Args: |
|||
mean (list[float]): Mean values of 3 channels. |
|||
std (list[float]): Std values of 3 channels. |
|||
""" |
|||
|
|||
def __init__(self, mean, std): |
|||
self.mean = mean |
|||
self.std = std |
|||
|
|||
def __call__(self, results): |
|||
if isinstance(results['img'], (list, tuple)): |
|||
results['img'] = [ |
|||
F.normalize(img, mean=self.mean, std=self.std) |
|||
for img in results['img'] |
|||
] |
|||
# modified: also normalize sam_img |
|||
results['sam_img'] = [ |
|||
F.normalize(sam_img, mean=self.mean, std=self.std) |
|||
for sam_img in results['sam_img'] |
|||
] |
|||
else: |
|||
results['img'] = F.normalize( |
|||
results['img'], mean=self.mean, std=self.std) |
|||
# modified: also normalize sam_img |
|||
results['sam_img'] = F.normalize( |
|||
results['sam_img'], mean=self.mean, std=self.std) |
|||
|
|||
return results |
@@ -0,0 +1,113 @@ |
|||
import cv2 |
|||
import numpy as np |
|||
|
|||
from mmpose.core.post_processing import (affine_transform, fliplr_joints, |
|||
get_affine_transform, get_warp_matrix, |
|||
warp_affine_joints) |
|||
from mmpose.datasets.builder import PIPELINES |
|||
|
|||
@PIPELINES.register_module() |
|||
class TopDownAffineSam: |
|||
"""Affine transform the image to make input. |
|||
|
|||
Required keys: 'img', 'joints_3d', 'joints_3d_visible', 'ann_info', 'scale', |
|||
'rotation' and 'center'. |
|||
|
|||
Modified keys: 'img', 'sam_img', 'joints_3d', and 'joints_3d_visible'. |
|||
|
|||
Args: |
|||
use_udp (bool): To use unbiased data processing. |
|||
Paper ref: Huang et al. The Devil is in the Details: Delving into |
|||
Unbiased Data Processing for Human Pose Estimation (CVPR 2020). |
|||
""" |
|||
|
|||
def __init__(self, use_udp=False): |
|||
self.use_udp = use_udp |
|||
|
|||
def __call__(self, results): |
|||
image_size = results['ann_info']['image_size'] |
|||
# modified: the SAM encoder expects a fixed 1024x1024 input |
|||
sam_image_size = np.array([1024, 1024]) |
|||
|
|||
img = results['img'] |
|||
joints_3d = results['joints_3d'] |
|||
joints_3d_visible = results['joints_3d_visible'] |
|||
c = results['center'] |
|||
s = results['scale'] |
|||
r = results['rotation'] |
|||
# modified: keep a copy of the image for the SAM branch |
|||
sam_img = img |
|||
|
|||
if self.use_udp: |
|||
trans = get_warp_matrix(r, c * 2.0, image_size - 1.0, s * 200.0) |
|||
# modified: warp matrix for the 1024x1024 SAM input |
|||
sam_trans = get_warp_matrix(r, c * 2.0, sam_image_size - 1.0, s * 200.0) |
|||
if not isinstance(img, list): |
|||
img = cv2.warpAffine( |
|||
img, |
|||
trans, (int(image_size[0]), int(image_size[1])), |
|||
flags=cv2.INTER_LINEAR) |
|||
# modified: warp the SAM copy to 1024x1024 |
|||
sam_img = cv2.warpAffine( |
|||
sam_img, |
|||
sam_trans, (int(sam_image_size[0]), int(sam_image_size[1])), |
|||
flags=cv2.INTER_LINEAR) |
|||
else: |
|||
img = [ |
|||
cv2.warpAffine( |
|||
i, |
|||
trans, (int(image_size[0]), int(image_size[1])), |
|||
flags=cv2.INTER_LINEAR) for i in img |
|||
] |
|||
# modified: warp each SAM copy to 1024x1024 |
|||
sam_img = [ |
|||
cv2.warpAffine( |
|||
i, |
|||
sam_trans, (int(sam_image_size[0]), int(sam_image_size[1])), |
|||
flags=cv2.INTER_LINEAR) for i in sam_img |
|||
] |
|||
|
|||
joints_3d[:, 0:2] = \ |
|||
warp_affine_joints(joints_3d[:, 0:2].copy(), trans) |
|||
|
|||
else: |
|||
trans = get_affine_transform(c, s, r, image_size) |
|||
# modified: affine transform for the 1024x1024 SAM input |
|||
sam_trans = get_affine_transform(c, s, r, sam_image_size) |
|||
if not isinstance(img, list): |
|||
img = cv2.warpAffine( |
|||
img, |
|||
trans, (int(image_size[0]), int(image_size[1])), |
|||
flags=cv2.INTER_LINEAR) |
|||
|
|||
# modified: warp the SAM copy to 1024x1024 |
|||
sam_img = cv2.warpAffine( |
|||
sam_img, |
|||
sam_trans, (int(sam_image_size[0]), int(sam_image_size[1])), |
|||
flags=cv2.INTER_LINEAR) |
|||
else: |
|||
img = [ |
|||
cv2.warpAffine( |
|||
i, |
|||
trans, (int(image_size[0]), int(image_size[1])), |
|||
flags=cv2.INTER_LINEAR) for i in img |
|||
] |
|||
# modified: warp each SAM copy to 1024x1024 |
|||
sam_img = [ |
|||
cv2.warpAffine( |
|||
i, |
|||
sam_trans, (int(sam_image_size[0]), int(sam_image_size[1])), |
|||
flags=cv2.INTER_LINEAR) for i in sam_img |
|||
] |
|||
|
|||
for i in range(results['ann_info']['num_joints']): |
|||
if joints_3d_visible[i, 0] > 0.0: |
|||
joints_3d[i, |
|||
0:2] = affine_transform(joints_3d[i, 0:2], trans) |
|||
|
|||
results['img'] = img |
|||
results['sam_img'] = sam_img |
|||
results['joints_3d'] = joints_3d |
|||
results['joints_3d_visible'] = joints_3d_visible |
|||
|
|||
return results |
@@ -0,0 +1 @@ |
|||
from .image_encoder import build_vit_sam |
@@ -0,0 +1,477 @@ |
|||
# -------------------------------------------------------------------- |
|||
# Copyright (c) Meta Platforms, Inc. and affiliates. |
|||
# All rights reserved. |
|||
|
|||
# This source code is licensed under the license found in the |
|||
# LICENSE file in the root directory of this source tree. |
|||
# -------------------------------------------------------------------- |
|||
|
|||
from typing import Optional, Tuple, Type |
|||
|
|||
import torch |
|||
import torch.nn as nn |
|||
import torch.nn.functional as F |
|||
|
|||
from functools import partial |
|||
|
|||
|
|||
# ---------------------- Vision Transformer of Segment-Anything ---------------------- |
|||
class ImageEncoderViT(nn.Module): |
|||
""" |
|||
The neck used in Segment-Anything is removed; only the ViT image encoder is kept. |
|||
""" |
|||
def __init__(self, |
|||
img_size : int = 1024, |
|||
patch_size : int = 16, |
|||
in_chans : int = 3, |
|||
embed_dim : int = 768, |
|||
depth : int = 12, |
|||
num_heads : int = 12, |
|||
mlp_ratio : float = 4.0, |
|||
qkv_bias : bool = True, |
|||
norm_layer : Type[nn.Module] = nn.LayerNorm, |
|||
act_layer : Type[nn.Module] = nn.GELU, |
|||
use_abs_pos : bool = True, |
|||
use_rel_pos : bool = True, |
|||
window_size : int = 0, |
|||
global_attn_indexes : Tuple[int, ...] = (), |
|||
checkpoint = None |
|||
) -> None: |
|||
super().__init__() |
|||
self.img_size = img_size |
|||
self.patch_size = patch_size |
|||
self.embed_dim = embed_dim |
|||
self.num_patches = (img_size // patch_size) ** 2 |
|||
# self.num_patches = (img_size[0] // patch_size) * (img_size[1] // patch_size) |
|||
self.pos_embed: Optional[nn.Parameter] = None |
|||
self.checkpoint = checkpoint |
|||
if use_abs_pos: |
|||
# Initialize absolute positional embedding with pretrain image size. |
|||
self.pos_embed = nn.Parameter( |
|||
torch.zeros(1, img_size // patch_size, img_size // patch_size, embed_dim) |
|||
) |
|||
|
|||
# ------------ Model parameters ------------ |
|||
## Patch embedding layer |
|||
self.patch_embed = PatchEmbed( |
|||
kernel_size=(patch_size, patch_size), |
|||
stride=(patch_size, patch_size), |
|||
in_chans=in_chans, |
|||
embed_dim=embed_dim, |
|||
) |
|||
|
|||
## ViT blocks |
|||
self.blocks = nn.ModuleList() |
|||
for i in range(depth): |
|||
block = Block(dim = embed_dim, |
|||
num_heads = num_heads, |
|||
mlp_ratio = mlp_ratio, |
|||
qkv_bias = qkv_bias, |
|||
norm_layer = norm_layer, |
|||
act_layer = act_layer, |
|||
use_rel_pos = use_rel_pos, |
|||
window_size = window_size if i not in global_attn_indexes else 0, |
|||
input_size = (img_size // patch_size, img_size // patch_size), |
|||
) |
|||
self.blocks.append(block) |
|||
|
|||
self.load_pretrained() |
|||
|
|||
def load_pretrained(self): |
|||
if self.checkpoint is not None: |
|||
print('Loading SAM pretrained weights from: {}'.format(self.checkpoint)) |
|||
# checkpoint state dict |
|||
checkpoint_state_dict = torch.load(self.checkpoint, map_location="cpu") |
|||
# model state dict |
|||
model_state_dict = self.state_dict() |
|||
encoder_state_dict = {} |
|||
# keep only image-encoder weights: strip the 'image_encoder.' prefix and check shapes |
|||
for k in list(checkpoint_state_dict.keys()): |
|||
if "image_encoder" in k and k[14:] in model_state_dict: |
|||
shape_model = tuple(model_state_dict[k[14:]].shape) |
|||
shape_checkpoint = tuple(checkpoint_state_dict[k].shape) |
|||
if shape_model == shape_checkpoint or "pos_embed" in k: |
|||
encoder_state_dict[k[14:]] = checkpoint_state_dict[k] |
|||
else: |
|||
print("Shape unmatch: ", k) |
|||
|
|||
# interpolate position embedding |
|||
# interpolate_pos_embed(self, encoder_state_dict, ((self.img_size[0] // self.patch_size), (self.img_size[1] // self.patch_size))) |
|||
interpolate_pos_embed(self, encoder_state_dict,) |
|||
|
|||
# load the weight |
|||
self.load_state_dict(encoder_state_dict, strict=False) |
|||
else: |
|||
print('No SAM pretrained weights loaded.') |
|||
|
|||
# @torch.no_grad() |
|||
def forward(self, x: torch.Tensor) -> torch.Tensor: |
|||
# with torch.no_grad(): |
|||
|
|||
x = self.patch_embed(x) |
|||
if self.pos_embed is not None: |
|||
x = x + self.pos_embed |
|||
|
|||
for blk in self.blocks: |
|||
x = blk(x) |
|||
|
|||
# [B, H, W, C] -> [B, N, C] |
|||
return x.flatten(1, 2) |
|||
|
|||
|
|||
# ---------------------- Model modules ---------------------- |
|||
class MLPBlock(nn.Module): |
|||
def __init__(self, |
|||
embedding_dim: int, |
|||
mlp_dim: int, |
|||
act: Type[nn.Module] = nn.GELU, |
|||
) -> None: |
|||
super().__init__() |
|||
self.lin1 = nn.Linear(embedding_dim, mlp_dim) |
|||
self.lin2 = nn.Linear(mlp_dim, embedding_dim) |
|||
self.act = act() |
|||
|
|||
def forward(self, x: torch.Tensor) -> torch.Tensor: |
|||
return self.lin2(self.act(self.lin1(x))) |
|||
|
|||
class LayerNorm2d(nn.Module): |
|||
def __init__(self, num_channels: int, eps: float = 1e-6) -> None: |
|||
super().__init__() |
|||
self.weight = nn.Parameter(torch.ones(num_channels)) |
|||
self.bias = nn.Parameter(torch.zeros(num_channels)) |
|||
self.eps = eps |
|||
|
|||
def forward(self, x: torch.Tensor) -> torch.Tensor: |
|||
u = x.mean(1, keepdim=True) |
|||
s = (x - u).pow(2).mean(1, keepdim=True) |
|||
x = (x - u) / torch.sqrt(s + self.eps) |
|||
x = self.weight[:, None, None] * x + self.bias[:, None, None] |
|||
|
|||
return x |
|||
|
|||
class Block(nn.Module): |
|||
def __init__(self, |
|||
dim : int, |
|||
num_heads : int, |
|||
mlp_ratio : float = 4.0, |
|||
qkv_bias : bool = True, |
|||
norm_layer : Type[nn.Module] = nn.LayerNorm, |
|||
act_layer : Type[nn.Module] = nn.GELU, |
|||
use_rel_pos : bool = False, |
|||
window_size : int = 0, |
|||
input_size : Optional[Tuple[int, int]] = None, |
|||
) -> None: |
|||
super().__init__() |
|||
# -------------- Basic parameters -------------- |
|||
self.window_size = window_size |
|||
# -------------- Model parameters -------------- |
|||
self.norm1 = norm_layer(dim) |
|||
self.attn = Attention(dim = dim, |
|||
num_heads = num_heads, |
|||
qkv_bias = qkv_bias, |
|||
use_rel_pos = use_rel_pos, |
|||
input_size = input_size if window_size == 0 else (window_size, window_size), |
|||
) |
|||
self.norm2 = norm_layer(dim) |
|||
self.mlp = MLPBlock(embedding_dim=dim, mlp_dim=int(dim * mlp_ratio), act=act_layer) |
|||
|
|||
def forward(self, x: torch.Tensor) -> torch.Tensor: |
|||
shortcut = x |
|||
x = self.norm1(x) |
|||
# Window partition |
|||
if self.window_size > 0: |
|||
H, W = x.shape[1], x.shape[2] |
|||
x, pad_hw = window_partition(x, self.window_size) |
|||
|
|||
x = self.attn(x) |
|||
# Reverse window partition |
|||
if self.window_size > 0: |
|||
x = window_unpartition(x, self.window_size, pad_hw, (H, W)) |
|||
|
|||
x = shortcut + x |
|||
x = x + self.mlp(self.norm2(x)) |
|||
|
|||
return x |
|||
|
|||
class Attention(nn.Module): |
|||
def __init__(self, |
|||
dim: int, |
|||
num_heads: int = 8, |
|||
qkv_bias: bool = True, |
|||
use_rel_pos: bool = False, |
|||
input_size: Optional[Tuple[int, int]] = None, |
|||
) -> None: |
|||
super().__init__() |
|||
# -------------- Basic parameters -------------- |
|||
self.num_heads = num_heads |
|||
head_dim = dim // num_heads |
|||
self.scale = head_dim**-0.5 |
|||
self.use_rel_pos = use_rel_pos |
|||
if self.use_rel_pos: |
|||
assert ( |
|||
input_size is not None |
|||
), "Input size must be provided if using relative positional encoding." |
|||
# initialize relative positional embeddings |
|||
self.rel_pos_h = nn.Parameter(torch.zeros(2 * input_size[0] - 1, head_dim)) |
|||
self.rel_pos_w = nn.Parameter(torch.zeros(2 * input_size[1] - 1, head_dim)) |
|||
|
|||
# -------------- Model parameters -------------- |
|||
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) |
|||
self.proj = nn.Linear(dim, dim) |
|||
|
|||
def forward(self, x: torch.Tensor) -> torch.Tensor: |
|||
B, H, W, _ = x.shape |
|||
# qkv with shape (3, B, nHead, H * W, C) |
|||
qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) |
|||
# q, k, v with shape (B * nHead, H * W, C) |
|||
q, k, v = qkv.reshape(3, B * self.num_heads, H * W, -1).unbind(0) |
|||
|
|||
attn = (q * self.scale) @ k.transpose(-2, -1) |
|||
|
|||
if self.use_rel_pos: |
|||
attn = add_decomposed_rel_pos(attn, q, self.rel_pos_h, self.rel_pos_w, (H, W), (H, W)) |
|||
|
|||
attn = attn.softmax(dim=-1) |
|||
x = (attn @ v).view(B, self.num_heads, H, W, -1).permute(0, 2, 3, 1, 4).reshape(B, H, W, -1) |
|||
x = self.proj(x) |
|||
|
|||
return x |
|||
|
|||
class PatchEmbed(nn.Module): |
|||
def __init__(self, |
|||
kernel_size : Tuple[int, int] = (16, 16), |
|||
stride : Tuple[int, int] = (16, 16), |
|||
padding : Tuple[int, int] = (0, 0), |
|||
in_chans : int = 3, |
|||
embed_dim : int = 768, |
|||
) -> None: |
|||
super().__init__() |
|||
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding) |
|||
|
|||
def forward(self, x: torch.Tensor) -> torch.Tensor: |
|||
x = self.proj(x) |
|||
# [B, C, H, W] -> [B, H, W, C] |
|||
x = x.permute(0, 2, 3, 1) |
|||
|
|||
return x |
|||
|
|||
|
|||
# ---------------------- Model functions ---------------------- |
|||
def window_partition(x: torch.Tensor, |
|||
window_size: int, |
|||
) -> Tuple[torch.Tensor, Tuple[int, int]]: |
|||
""" |
|||
Partition into non-overlapping windows with padding if needed. |
|||
Args: |
|||
x (tensor): input tokens with [B, H, W, C]. |
|||
window_size (int): window size. |
|||
|
|||
Returns: |
|||
windows: windows after partition with [B * num_windows, window_size, window_size, C]. |
|||
(Hp, Wp): padded height and width before partition |
|||
""" |
|||
B, H, W, C = x.shape |
|||
|
|||
pad_h = (window_size - H % window_size) % window_size |
|||
pad_w = (window_size - W % window_size) % window_size |
|||
if pad_h > 0 or pad_w > 0: |
|||
x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h)) |
|||
Hp, Wp = H + pad_h, W + pad_w |
|||
|
|||
x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C) |
|||
windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) |
|||
|
|||
return windows, (Hp, Wp) |
|||
|
|||
def window_unpartition(windows: torch.Tensor, |
|||
window_size: int, |
|||
pad_hw: Tuple[int, int], |
|||
hw: Tuple[int, int], |
|||
) -> torch.Tensor: |
|||
""" |
|||
Reverse window partition back to the original sequence and remove padding. |
|||
Args: |
|||
windows (tensor): input tokens with [B * num_windows, window_size, window_size, C]. |
|||
window_size (int): window size. |
|||
pad_hw (Tuple): padded height and width (Hp, Wp). |
|||
hw (Tuple): original height and width (H, W) before padding. |
|||
|
|||
Returns: |
|||
x: unpartitioned sequences with [B, H, W, C]. |
|||
""" |
|||
Hp, Wp = pad_hw |
|||
H, W = hw |
|||
B = windows.shape[0] // (Hp * Wp // window_size // window_size) |
|||
x = windows.view(B, Hp // window_size, Wp // window_size, window_size, window_size, -1) |
|||
x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1) |
|||
|
|||
if Hp > H or Wp > W: |
|||
x = x[:, :H, :W, :].contiguous() |
|||
|
|||
return x |
|||
|
|||
def get_rel_pos(q_size: int, |
|||
k_size: int, |
|||
rel_pos: torch.Tensor, |
|||
)-> torch.Tensor: |
|||
""" |
|||
Get relative positional embeddings according to the relative positions of |
|||
query and key sizes. |
|||
Args: |
|||
q_size (int): size of query q. |
|||
k_size (int): size of key k. |
|||
rel_pos (Tensor): relative position embeddings (L, C). |
|||
|
|||
Returns: |
|||
Extracted positional embeddings according to relative positions. |
|||
""" |
|||
max_rel_dist = int(2 * max(q_size, k_size) - 1) |
|||
# Interpolate rel pos if needed. |
|||
if rel_pos.shape[0] != max_rel_dist: |
|||
# Interpolate rel pos. |
|||
rel_pos_resized = F.interpolate( |
|||
rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1), |
|||
size=max_rel_dist, |
|||
mode="linear", |
|||
) |
|||
rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0) |
|||
else: |
|||
rel_pos_resized = rel_pos |
|||
|
|||
# Scale the coords with short length if shapes for q and k are different. |
|||
q_coords = torch.arange(q_size)[:, None] * max(k_size / q_size, 1.0) |
|||
k_coords = torch.arange(k_size)[None, :] * max(q_size / k_size, 1.0) |
|||
relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0) |
|||
|
|||
return rel_pos_resized[relative_coords.long()] |
|||
|
|||
def add_decomposed_rel_pos(attn : torch.Tensor, |
|||
q : torch.Tensor, |
|||
rel_pos_h : torch.Tensor, |
|||
rel_pos_w : torch.Tensor, |
|||
q_size : Tuple[int, int], |
|||
k_size : Tuple[int, int], |
|||
) -> torch.Tensor: |
|||
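"""Add decomposed (per-axis) relative positional terms to the attention logits, |
|||
as done in the SAM image encoder: rel_pos_h acts along the height axis and |
|||
rel_pos_w along the width axis of the query/key grids. |
|||
""" |
|||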
q_h, q_w = q_size |
|||
k_h, k_w = k_size |
|||
Rh = get_rel_pos(q_h, k_h, rel_pos_h) |
|||
Rw = get_rel_pos(q_w, k_w, rel_pos_w) |
|||
|
|||
B, _, dim = q.shape |
|||
r_q = q.reshape(B, q_h, q_w, dim) |
|||
rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, Rh) |
|||
rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, Rw) |
|||
|
|||
attn = ( |
|||
attn.view(B, q_h, q_w, k_h, k_w) + rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :] |
|||
).view(B, q_h * q_w, k_h * k_w) |
|||
|
|||
return attn |
|||
|
|||
def interpolate_pos_embed(model, checkpoint_model): |
|||
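"""Bicubically resize the checkpoint's absolute position embedding to this |
|||
model's patch grid, keeping any extra (non-patch) tokens unchanged. |
|||
""" |
|||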
if 'pos_embed' in checkpoint_model: |
|||
# Pos embed from checkpoint |
|||
pos_embed_checkpoint = checkpoint_model['pos_embed'] |
|||
embedding_size = pos_embed_checkpoint.shape[-1] |
|||
# Pos embed from model |
|||
pos_embed_model = model.pos_embed |
|||
num_patches = model.num_patches |
|||
# [B, H, W, C] -> [B, N, C] |
|||
pos_embed_checkpoint = pos_embed_checkpoint.flatten(1, 2) |
|||
pos_embed_model = pos_embed_model.flatten(1, 2) |
|||
|
|||
orig_num_positions = pos_embed_model.shape[-2] |
|||
num_extra_tokens = orig_num_positions - num_patches |
|||
|
|||
# height (== width) for the checkpoint position embedding |
|||
orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5) |
|||
new_size = int(num_patches ** 0.5) |
|||
|
|||
# height (== width) for the new position embedding |
|||
# class_token and dist_token are kept unchanged |
|||
if orig_size != new_size: |
|||
print("- Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size)) |
|||
extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens] |
|||
# only the position tokens are interpolated |
|||
pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:] |
|||
pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2) |
|||
pos_tokens = torch.nn.functional.interpolate(pos_tokens, |
|||
size=(new_size,new_size), |
|||
mode='bicubic', |
|||
align_corners=False) |
|||
pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2) |
|||
new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1) |
|||
new_pos_embed = new_pos_embed.reshape(-1, int(orig_num_positions ** 0.5), int(orig_num_positions ** 0.5), embedding_size) |
|||
checkpoint_model['pos_embed'] = new_pos_embed |
|||
|
|||
|
|||
# ------------------------ Model Functions ------------------------ |
|||
def build_vit_sam(model_name="vit_h", img_size=1024, patch_size=16, img_dim=3, checkpoint=None): |
|||
if model_name == "vit_b": |
|||
return ImageEncoderViT(img_size=img_size, |
|||
patch_size=patch_size, |
|||
in_chans=img_dim, |
|||
embed_dim=768, |
|||
depth=12, |
|||
num_heads=12, |
|||
mlp_ratio=4.0, |
|||
norm_layer=partial(nn.LayerNorm, eps=1e-6), |
|||
global_attn_indexes=[2, 5, 8, 11], |
|||
window_size=14, |
|||
checkpoint=checkpoint, |
|||
) |
|||
if model_name == "vit_l": |
|||
return ImageEncoderViT(img_size=img_size, |
|||
patch_size=patch_size, |
|||
in_chans=img_dim, |
|||
embed_dim=1024, |
|||
depth=24, |
|||
num_heads=16, |
|||
mlp_ratio=4.0, |
|||
norm_layer=partial(nn.LayerNorm, eps=1e-6), |
|||
global_attn_indexes=[5, 11, 17, 23], |
|||
window_size=14, |
|||
checkpoint=checkpoint, |
|||
) |
|||
if model_name == "vit_h": |
|||
return ImageEncoderViT(img_size=img_size, |
|||
patch_size=patch_size, |
|||
in_chans=img_dim, |
|||
embed_dim=1280, |
|||
depth=32, |
|||
num_heads=16, |
|||
mlp_ratio=4.0, |
|||
norm_layer=partial(nn.LayerNorm, eps=1e-6), |
|||
global_attn_indexes=[7, 15, 23, 31], |
|||
window_size=14, |
|||
checkpoint=checkpoint, |
|||
) |
|||
|
|||
|
|||
if __name__ == '__main__': |
|||
import torch |
|||
from thop import profile |
|||
|
|||
# Prepare an image as the input |
|||
bs, c, h, w = 2, 3, 1024, 1024 |
|||
x = torch.randn(bs, c, h, w) |
|||
patch_size = 16 |
|||
device = torch.device('cuda:1' if torch.cuda.is_available() else 'cpu') |
|||
|
|||
# Build model |
|||
model = build_vit_sam(model_name='vit_b', checkpoint="/home/fhw/code/ViTPose/checkpoints/sam/sam_vit_b_01ec64.pth") |
|||
if torch.cuda.is_available(): |
|||
x = x.to(device) |
|||
model = model.to(device) |
|||
|
|||
# Inference |
|||
outputs = model(x) |
|||
print(outputs.shape) |
|||
|
|||
# Compute FLOPs & Params |
|||
print('==============================') |
|||
model.eval() |
|||
flops, params = profile(model, inputs=(x, ), verbose=False) |
|||
print('GFLOPs : {:.2f}'.format(flops / 1e9 * 2)) |
|||
print('Params : {:.2f} M'.format(params / 1e6)) |
@@ -0,0 +1,483 @@ |
|||
# Copyright (c) OpenMMLab. All rights reserved. |
|||
import math |
|||
|
|||
import torch |
|||
from functools import partial |
|||
import torch.nn as nn |
|||
import torch.nn.functional as F |
|||
import torch.utils.checkpoint as checkpoint |
|||
|
|||
from timm.models.layers import drop_path, to_2tuple, trunc_normal_ |
|||
|
|||
from ..builder import BACKBONES |
|||
from .base_backbone import BaseBackbone |
|||
|
|||
from .sam_vit import build_vit_sam |
|||
|
|||
|
|||
def get_abs_pos(abs_pos, h, w, ori_h, ori_w, has_cls_token=True): |
|||
""" |
|||
Calculate absolute positional embeddings. If needed, resize embeddings and remove cls_token |
|||
dimension for the original embeddings. |
|||
Args: |
|||
abs_pos (Tensor): absolute positional embeddings with (1, num_position, C). |
|||
has_cls_token (bool): If true, has 1 embedding in abs_pos for cls token. |
|||
hw (Tuple): size of input image tokens. |
|||
|
|||
Returns: |
|||
Absolute positional embeddings after processing with shape (1, H, W, C) |
|||
""" |
|||
cls_token = None |
|||
B, L, C = abs_pos.shape |
|||
if has_cls_token: |
|||
cls_token = abs_pos[:, 0:1] |
|||
abs_pos = abs_pos[:, 1:] |
|||
|
|||
if ori_h != h or ori_w != w: |
|||
new_abs_pos = F.interpolate( |
|||
abs_pos.reshape(1, ori_h, ori_w, -1).permute(0, 3, 1, 2), |
|||
size=(h, w), |
|||
mode="bicubic", |
|||
align_corners=False, |
|||
).permute(0, 2, 3, 1).reshape(B, -1, C) |
|||
|
|||
else: |
|||
new_abs_pos = abs_pos |
|||
|
|||
if cls_token is not None: |
|||
new_abs_pos = torch.cat([cls_token, new_abs_pos], dim=1) |
|||
return new_abs_pos |
|||
|
|||
class DropPath(nn.Module): |
|||
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). |
|||
""" |
|||
def __init__(self, drop_prob=None): |
|||
super(DropPath, self).__init__() |
|||
self.drop_prob = drop_prob |
|||
|
|||
def forward(self, x): |
|||
return drop_path(x, self.drop_prob, self.training) |
|||
|
|||
def extra_repr(self): |
|||
return 'p={}'.format(self.drop_prob) |
|||
|
|||
class Mlp(nn.Module): |
|||
def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): |
|||
super().__init__() |
|||
out_features = out_features or in_features |
|||
hidden_features = hidden_features or in_features |
|||
self.fc1 = nn.Linear(in_features, hidden_features) |
|||
self.act = act_layer() |
|||
self.fc2 = nn.Linear(hidden_features, out_features) |
|||
self.drop = nn.Dropout(drop) |
|||
|
|||
def forward(self, x): |
|||
x = self.fc1(x) |
|||
x = self.act(x) |
|||
x = self.fc2(x) |
|||
x = self.drop(x) |
|||
return x |
|||
|
|||
class Attention(nn.Module): |
|||
def __init__( |
|||
self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., |
|||
proj_drop=0., attn_head_dim=None,): |
|||
super().__init__() |
|||
self.num_heads = num_heads |
|||
head_dim = dim // num_heads |
|||
self.dim = dim |
|||
|
|||
if attn_head_dim is not None: |
|||
head_dim = attn_head_dim |
|||
all_head_dim = head_dim * self.num_heads |
|||
|
|||
self.scale = qk_scale or head_dim ** -0.5 |
|||
|
|||
self.qkv = nn.Linear(dim, all_head_dim * 3, bias=qkv_bias) |
|||
|
|||
self.attn_drop = nn.Dropout(attn_drop) |
|||
self.proj = nn.Linear(all_head_dim, dim) |
|||
self.proj_drop = nn.Dropout(proj_drop) |
|||
|
|||
def forward(self, x): |
|||
B, N, C = x.shape |
|||
qkv = self.qkv(x) |
|||
qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) |
|||
q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) |
|||
|
|||
q = q * self.scale |
|||
attn = (q @ k.transpose(-2, -1)) |
|||
|
|||
attn = attn.softmax(dim=-1) |
|||
attn = self.attn_drop(attn) |
|||
|
|||
x = (attn @ v).transpose(1, 2).reshape(B, N, -1) |
|||
x = self.proj(x) |
|||
x = self.proj_drop(x) |
|||
|
|||
return x |
|||
|
|||
class Block(nn.Module): |
|||
|
|||
def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, |
|||
drop=0., attn_drop=0., drop_path=0., act_layer=nn.GELU, |
|||
norm_layer=nn.LayerNorm, attn_head_dim=None |
|||
): |
|||
super().__init__() |
|||
|
|||
self.norm1 = norm_layer(dim) |
|||
self.attn = Attention( |
|||
dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, |
|||
attn_drop=attn_drop, proj_drop=drop, attn_head_dim=attn_head_dim |
|||
) |
|||
|
|||
# NOTE: drop path for stochastic depth, we shall see if this is better than dropout here |
|||
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() |
|||
self.norm2 = norm_layer(dim) |
|||
mlp_hidden_dim = int(dim * mlp_ratio) |
|||
self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) |
|||
|
|||
def forward(self, x): |
|||
x = x + self.drop_path(self.attn(self.norm1(x))) |
|||
x = x + self.drop_path(self.mlp(self.norm2(x))) |
|||
return x |
|||
|
|||
|
|||
class PatchEmbed(nn.Module): |
|||
""" Image to Patch Embedding |
|||
""" |
|||
def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, ratio=1): |
|||
super().__init__() |
|||
img_size = to_2tuple(img_size) |
|||
patch_size = to_2tuple(patch_size) |
|||
num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0]) * (ratio ** 2) |
|||
self.patch_shape = (int(img_size[0] // patch_size[0] * ratio), int(img_size[1] // patch_size[1] * ratio)) |
|||
self.origin_patch_shape = (int(img_size[0] // patch_size[0]), int(img_size[1] // patch_size[1])) |
|||
self.img_size = img_size |
|||
self.patch_size = patch_size |
|||
self.num_patches = num_patches |
|||
|
|||
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=(patch_size[0] // ratio), padding=4 + 2 * (ratio//2-1)) |
|||
|
|||
def forward(self, x, **kwargs): |
|||
B, C, H, W = x.shape |
|||
x = self.proj(x) |
|||
Hp, Wp = x.shape[2], x.shape[3] |
|||
|
|||
x = x.flatten(2).transpose(1, 2) |
|||
return x, (Hp, Wp) |
|||
|
|||
|
|||
class HybridEmbed(nn.Module): |
|||
""" CNN Feature Map Embedding |
|||
Extract feature map from CNN, flatten, project to embedding dim. |
|||
""" |
|||
def __init__(self, backbone, img_size=224, feature_size=None, in_chans=3, embed_dim=768): |
|||
super().__init__() |
|||
assert isinstance(backbone, nn.Module) |
|||
img_size = to_2tuple(img_size) |
|||
self.img_size = img_size |
|||
self.backbone = backbone |
|||
if feature_size is None: |
|||
with torch.no_grad(): |
|||
training = backbone.training |
|||
if training: |
|||
backbone.eval() |
|||
o = self.backbone(torch.zeros(1, in_chans, img_size[0], img_size[1]))[-1] |
|||
feature_size = o.shape[-2:] |
|||
feature_dim = o.shape[1] |
|||
backbone.train(training) |
|||
else: |
|||
feature_size = to_2tuple(feature_size) |
|||
feature_dim = self.backbone.feature_info.channels()[-1] |
|||
self.num_patches = feature_size[0] * feature_size[1] |
|||
self.proj = nn.Linear(feature_dim, embed_dim) |
|||
|
|||
def forward(self, x): |
|||
x = self.backbone(x)[-1] |
|||
x = x.flatten(2).transpose(1, 2) |
|||
x = self.proj(x) |
|||
return x |
|||
|
|||
class Cross_Attention(nn.Module): |
|||
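"""Multi-head cross-attention: queries come from x_1, keys from x_2, values from x_3. |
|||
""" |
|||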
def __init__(self, dim, num_heads=12, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.): |
|||
super().__init__() |
|||
self.num_heads = num_heads |
|||
head_dim = dim // num_heads |
|||
self.scale = qk_scale or head_dim ** -0.5 |
|||
|
|||
self.self_attn = Attention( |
|||
dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, |
|||
attn_drop=attn_drop, proj_drop=0.) |
|||
|
|||
self.linear_q = nn.Linear(dim, dim, bias=qkv_bias) |
|||
self.linear_k = nn.Linear(dim, dim, bias=qkv_bias) |
|||
self.linear_v = nn.Linear(dim, dim, bias=qkv_bias) |
|||
|
|||
self.attn_drop = nn.Dropout(attn_drop) |
|||
self.proj = nn.Linear(dim, dim) |
|||
self.proj_drop = nn.Dropout(proj_drop) |
|||
|
|||
def forward(self, x_1, x_2, x_3): |
|||
B, N, C = x_1.shape # q |
|||
B, N_1, C = x_2.shape # k, v |
|||
|
|||
q = self.linear_q(x_1).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) # (B, num_heads, N, C//num_heads) |
|||
k = self.linear_k(x_2).reshape(B, N_1, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) # (B, num_heads, N_1, C//num_heads) |
|||
v = self.linear_v(x_3).reshape(B, N_1, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) # (B, num_heads, N_1, C//num_heads) |
|||
|
|||
attn = (q @ k.transpose(-2, -1)) * self.scale # (B, num_heads, N, N_1) |
|||
attn = attn.softmax(dim=-1) |
|||
attn = self.attn_drop(attn) |
|||
|
|||
# import matplotlib.pyplot as plt |
|||
# import seaborn as sns |
|||
|
|||
# attn_map = attn[0][0].cpu().detach().numpy() |
|||
# plt.figure(figsize=(20, 10)) |
|||
# sns.heatmap(attn_map, annot=True, fmt='.2f', cmap='coolwarm') |
|||
|
|||
# plt.title('Cross Attention Map') |
|||
# plt.xlabel('N_1') |
|||
# plt.ylabel('N') |
|||
|
|||
# plt.savefig('/home/fhw/code/ViTPose/test/cross_attn_map.png') |
|||
# plt.close() |
|||
|
|||
x = (attn @ v).transpose(1, 2).reshape(B, N, C) # (B, N, C) |
|||
x = self.proj(x) |
|||
x = self.proj_drop(x) |
|||
return x |
|||
|
|||
class CustomAttentionFFN(nn.Module): |
|||
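"""Fusion block: self-attention on x1, cross-attention from x1 (queries) to x2 |
|||
(keys/values), then an FFN, each followed by a residual add and LayerNorm. |
|||
""" |
|||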
def __init__(self, dim, num_heads=12, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.): |
|||
super().__init__() |
|||
self.self_attn = Attention( |
|||
dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, |
|||
attn_drop=attn_drop, proj_drop=proj_drop) |
|||
|
|||
self.cross_attn = Cross_Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, \ |
|||
qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=proj_drop) |
|||
|
|||
self.ffn = nn.Sequential( |
|||
nn.Linear(dim, dim * 4), |
|||
nn.GELU(), |
|||
nn.Linear(dim * 4, dim) |
|||
) |
|||
self.norm1 = nn.LayerNorm(dim) |
|||
self.norm2 = nn.LayerNorm(dim) |
|||
self.norm3 = nn.LayerNorm(dim) |
|||
|
|||
def forward(self, x1, x2): |
|||
x1 = self.norm1(x1 + self.self_attn(x1)) |
|||
|
|||
x1 = self.norm2(x1 + self.cross_attn(x1, x2, x2)) |
|||
|
|||
x1 = self.norm3(x1 + self.ffn(x1)) |
|||
|
|||
return x1 |
|||
|
|||
@BACKBONES.register_module() |
|||
class ViTSam(BaseBackbone): |
|||
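"""ViTPose-style ViT backbone augmented with a frozen SAM ViT-B image encoder. |
|||
The pose branch processes `img`; the SAM branch (`self.sam_vit`) processes |
|||
`sam_img` under torch.no_grad(), and the two token sets are fused by |
|||
CustomAttentionFFN before reshaping back to a feature map. |
|||
""" |
|||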
|
|||
def __init__(self, |
|||
img_size=224, patch_size=16, in_chans=3, num_classes=80, embed_dim=768, depth=12, |
|||
num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0., |
|||
drop_path_rate=0., hybrid_backbone=None, norm_layer=None, use_checkpoint=False, |
|||
frozen_stages=-1, ratio=1, last_norm=True, |
|||
patch_padding='pad', freeze_attn=False, freeze_ffn=False, samvit_checkpoint=None |
|||
): |
|||
# Protect mutable default arguments |
|||
super(ViTSam, self).__init__() |
|||
norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6) |
|||
self.num_classes = num_classes |
|||
self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models |
|||
self.frozen_stages = frozen_stages |
|||
self.use_checkpoint = use_checkpoint |
|||
self.patch_padding = patch_padding |
|||
self.freeze_attn = freeze_attn |
|||
self.freeze_ffn = freeze_ffn |
|||
self.depth = depth |
|||
|
|||
if hybrid_backbone is not None: |
|||
self.patch_embed = HybridEmbed( |
|||
hybrid_backbone, img_size=img_size, in_chans=in_chans, embed_dim=embed_dim) |
|||
else: |
|||
self.patch_embed = PatchEmbed( |
|||
img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, ratio=ratio) |
|||
num_patches = self.patch_embed.num_patches |
|||
|
|||
# since the pretraining model has class token |
|||
self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim)) |
|||
|
|||
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule |
|||
|
|||
self.blocks = nn.ModuleList([ |
|||
Block( |
|||
dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, |
|||
drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, |
|||
) |
|||
for i in range(depth)]) |
|||
|
|||
self.last_norm = norm_layer(embed_dim) if last_norm else nn.Identity() |
|||
|
|||
if self.pos_embed is not None: |
|||
trunc_normal_(self.pos_embed, std=.02) |
|||
|
|||
self._freeze_stages() |
|||
|
|||
# ======================== SAM-ViT ======================== |
|||
self.sam_vit = build_vit_sam(model_name='vit_b', checkpoint=samvit_checkpoint) |
|||
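# keep the SAM encoder frozen: eval mode and no gradients |
|||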
self.sam_vit.eval() |
|||
for param in self.sam_vit.parameters(): |
|||
param.requires_grad = False |
|||
|
|||
# self.cross_attn = Cross_Attention(embed_dim, num_heads=num_heads, qkv_bias=qkv_bias, \ |
|||
# qk_scale=qk_scale, attn_drop=attn_drop_rate, proj_drop=drop_rate) |
|||
|
|||
self.custom_attn_ffn = CustomAttentionFFN(embed_dim, num_heads=num_heads, qkv_bias=qkv_bias, \ |
|||
qk_scale=qk_scale, attn_drop=attn_drop_rate, proj_drop=drop_rate) |
|||
|
|||
def _freeze_stages(self): |
|||
"""Freeze parameters.""" |
|||
if self.frozen_stages >= 0: |
|||
self.patch_embed.eval() |
|||
for param in self.patch_embed.parameters(): |
|||
param.requires_grad = False |
|||
|
|||
for i in range(0, self.frozen_stages): |
|||
m = self.blocks[i] |
|||
m.eval() |
|||
for param in m.parameters(): |
|||
param.requires_grad = False |
|||
|
|||
if self.freeze_attn: |
|||
for i in range(0, self.depth): |
|||
m = self.blocks[i] |
|||
m.attn.eval() |
|||
m.norm1.eval() |
|||
for param in m.attn.parameters(): |
|||
param.requires_grad = False |
|||
for param in m.norm1.parameters(): |
|||
param.requires_grad = False |
|||
|
|||
if self.freeze_ffn: |
|||
self.pos_embed.requires_grad = False |
|||
self.patch_embed.eval() |
|||
for param in self.patch_embed.parameters(): |
|||
param.requires_grad = False |
|||
for i in range(0, self.depth): |
|||
m = self.blocks[i] |
|||
m.mlp.eval() |
|||
m.norm2.eval() |
|||
for param in m.mlp.parameters(): |
|||
param.requires_grad = False |
|||
for param in m.norm2.parameters(): |
|||
param.requires_grad = False |
|||
|
|||
def init_weights(self, pretrained=None): |
|||
"""Initialize the weights in backbone. |
|||
Args: |
|||
pretrained (str, optional): Path to pre-trained weights. |
|||
Defaults to None. |
|||
""" |
|||
super().init_weights(pretrained, patch_padding=self.patch_padding) |
|||
|
|||
if pretrained is None: |
|||
def _init_weights(m): |
|||
if isinstance(m, nn.Linear): |
|||
trunc_normal_(m.weight, std=.02) |
|||
if isinstance(m, nn.Linear) and m.bias is not None: |
|||
nn.init.constant_(m.bias, 0) |
|||
elif isinstance(m, nn.LayerNorm): |
|||
nn.init.constant_(m.bias, 0) |
|||
nn.init.constant_(m.weight, 1.0) |
|||
|
|||
self.apply(_init_weights) |
|||
|
|||
def get_num_layers(self): |
|||
return len(self.blocks) |
|||
|
|||
@torch.jit.ignore |
|||
def no_weight_decay(self): |
|||
return {'pos_embed', 'cls_token'} |
|||
|
|||
def forward_features(self, x): |
|||
B, C, H, W = x.shape |
|||
x, (Hp, Wp) = self.patch_embed(x) |
|||
|
|||
if self.pos_embed is not None: |
|||
# fit for multiple GPU training |
|||
# since the first element for pos embed (sin-cos manner) is zero, it will cause no difference |
|||
x = x + self.pos_embed[:, 1:] + self.pos_embed[:, :1] |
|||
|
|||
for blk in self.blocks: |
|||
if self.use_checkpoint: |
|||
x = checkpoint.checkpoint(blk, x) |
|||
else: |
|||
x = blk(x) |
|||
|
|||
x = self.last_norm(x) |
|||
|
|||
return x, Hp, Wp |
|||
|
|||
def forward(self, x1, x2): |
|||
import time |
|||
B, _, _, _ = x1.shape |
|||
x1, Hp, Wp = self.forward_features(x1) # B, N_vitpose, C |
|||
|
|||
with torch.no_grad(): |
|||
# start_time = time.time() |
|||
# self.sam_vit.eval() |
|||
x2 = self.sam_vit(x2) # B, N_sam, C |
|||
|
|||
# end_time = time.time() |
|||
# print('SAM-ViT forward time: {:.4f} s'.format(end_time - start_time)) |
|||
|
|||
# x1 = x1 + self.cross_attn(x1, x2, x2) |
|||
|
|||
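# fuse pose tokens (queries) with the frozen SAM tokens (keys/values) |
|||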
x1 = self.custom_attn_ffn(x1, x2) |
|||
|
|||
xp = x1.permute(0, 2, 1).reshape(B, -1, Hp, Wp).contiguous() # B, C, Hp, Wp |
|||
return xp |
|||
|
|||
def train(self, mode=True): |
|||
"""Convert the model into training mode.""" |
|||
super().train(mode) |
|||
self._freeze_stages() |
|||
|
|||
|
|||
if __name__ == '__main__': |
|||
from thop import profile |
|||
from mmcv.runner import load_checkpoint |
|||
|
|||
# Prepare an image as the input |
|||
bs, c, h, w = 2, 3, 1024, 1024 |
|||
x1 = torch.randn(bs, c, 256, 192) |
|||
x2 = torch.randn(bs, c, h, w) |
|||
patch_size = 16 |
|||
device = torch.device('cuda:1' if torch.cuda.is_available() else 'cpu') |
|||
|
|||
# Build model |
|||
model = ViTSam(img_size=(256, 192), patch_size=16, embed_dim=768, depth=12, num_heads=12, ratio=1, |
|||
use_checkpoint=False, mlp_ratio=4, qkv_bias=True, drop_path_rate=0.3, |
|||
samvit_checkpoint='/home/fhw/code/ViTPose/checkpoints/sam/sam_vit_b_01ec64.pth') |
|||
|
|||
|
|||
if torch.cuda.is_available(): |
|||
x1 = x1.to(device) |
|||
x2 = x2.to(device) |
|||
model = model.to(device) |
|||
|
|||
with torch.no_grad(): |
|||
model.eval() |
|||
# Inference |
|||
outputs = model(x1, x2) |
|||
print(outputs.shape) |
|||
|
|||
# Compute FLOPs & Params |
|||
print('==============================') |
|||
model.eval() |
|||
flops, params = profile(model, inputs=(x1, x2), verbose=False) |
|||
print('GFLOPs : {:.2f}'.format(flops / 1e9 * 2)) |
|||
print('Params : {:.2f} M'.format(params / 1e6)) |
@@ -0,0 +1,322 @@ |
|||
# Copyright (c) OpenMMLab. All rights reserved. |
|||
import warnings |
|||
import logging |
|||
|
|||
import mmcv |
|||
import numpy as np |
|||
from mmcv.image import imwrite |
|||
from mmcv.utils.misc import deprecated_api_warning |
|||
from mmcv.visualization.image import imshow |
|||
from mmcv_custom.checkpoint import load_checkpoint |
|||
|
|||
from mmpose.core import imshow_bboxes, imshow_keypoints |
|||
from .. import builder |
|||
from ..builder import POSENETS |
|||
from .base import BasePose |
|||
|
|||
try: |
|||
from mmcv.runner import auto_fp16 |
|||
except ImportError: |
|||
warnings.warn('auto_fp16 from mmpose will be deprecated from v0.15.0. ' |
|||
'Please install mmcv>=1.1.4') |
|||
from mmpose.core import auto_fp16 |
|||
|
|||
|
|||
@POSENETS.register_module() |
|||
class TopDownSelf(BasePose): |
|||
"""Top-down pose detectors. |
|||
|
|||
Args: |
|||
backbone (dict): Backbone modules to extract feature. |
|||
keypoint_head (dict): Keypoint head to process feature. |
|||
train_cfg (dict): Config for training. Default: None. |
|||
test_cfg (dict): Config for testing. Default: None. |
|||
pretrained (str): Path to the pretrained models. |
|||
loss_pose (None): Deprecated argument. Please use |
|||
`loss_keypoint` for heads instead. |
|||
""" |
|||
|
|||
def __init__(self, |
|||
backbone, |
|||
neck=None, |
|||
keypoint_head=None, |
|||
train_cfg=None, |
|||
test_cfg=None, |
|||
pretrained=None, |
|||
loss_pose=None): |
|||
super().__init__() |
|||
self.fp16_enabled = False |
|||
|
|||
self.backbone = builder.build_backbone(backbone) |
|||
|
|||
self.train_cfg = train_cfg |
|||
self.test_cfg = test_cfg |
|||
|
|||
if neck is not None: |
|||
self.neck = builder.build_neck(neck) |
|||
|
|||
if keypoint_head is not None: |
|||
keypoint_head['train_cfg'] = train_cfg |
|||
keypoint_head['test_cfg'] = test_cfg |
|||
|
|||
if 'loss_keypoint' not in keypoint_head and loss_pose is not None: |
|||
warnings.warn( |
|||
'`loss_pose` for TopDown is deprecated, ' |
|||
'use `loss_keypoint` for heads instead. See ' |
|||
'https://github.com/open-mmlab/mmpose/pull/382' |
|||
' for more information.', DeprecationWarning) |
|||
keypoint_head['loss_keypoint'] = loss_pose |
|||
|
|||
self.keypoint_head = builder.build_head(keypoint_head) |
|||
|
|||
self.init_weights(pretrained=pretrained) |
|||
|
|||
@property |
|||
def with_neck(self): |
|||
"""Check if has neck.""" |
|||
return hasattr(self, 'neck') |
|||
|
|||
@property |
|||
def with_keypoint(self): |
|||
"""Check if has keypoint_head.""" |
|||
return hasattr(self, 'keypoint_head') |
|||
|
|||
def init_weights(self, pretrained=None): |
|||
"""Weight initialization for model.""" |
|||
self.backbone.init_weights(pretrained) |
|||
if self.with_neck: |
|||
self.neck.init_weights() |
|||
if self.with_keypoint: |
|||
self.keypoint_head.init_weights() |
|||
|
|||
@auto_fp16(apply_to=('img', 'sam_img', )) |
|||
def forward(self, |
|||
img, |
|||
sam_img,  # input for the SAM encoder |
|||
target=None, |
|||
target_weight=None, |
|||
img_metas=None, |
|||
return_loss=True, |
|||
return_heatmap=False, |
|||
**kwargs): |
|||
"""Calls either forward_train or forward_test depending on whether |
|||
return_loss=True. Note this setting will change the expected inputs. |
|||
When `return_loss=True`, img and img_meta are single-nested (i.e. |
|||
Tensor and List[dict]), and when `return_loss=False`, img and img_meta |
|||
should be double nested (i.e. List[Tensor], List[List[dict]]), with |
|||
the outer list indicating test time augmentations. |
|||
|
|||
Note: |
|||
- batch_size: N |
|||
- num_keypoints: K |
|||
- num_img_channel: C (Default: 3) |
|||
- img height: imgH |
|||
- img width: imgW |
|||
- heatmaps height: H |
|||
- heatmaps width: W |
|||
|
|||
Args: |
|||
img (torch.Tensor[NxCximgHximgW]): Input images. |
|||
sam_img (torch.Tensor[NxCx1024x1024]): Input images for the frozen SAM encoder. |
|||
target (torch.Tensor[NxKxHxW]): Target heatmaps. |
|||
target_weight (torch.Tensor[NxKx1]): Weights across |
|||
different joint types. |
|||
img_metas (list(dict)): Information about data augmentation |
|||
By default this includes: |
|||
|
|||
- "image_file: path to the image file |
|||
- "center": center of the bbox |
|||
- "scale": scale of the bbox |
|||
- "rotation": rotation of the bbox |
|||
- "bbox_score": score of bbox |
|||
return_loss (bool): Option to return loss. `return_loss=True` |
|||
for training, `return_loss=False` for validation & test. |
|||
return_heatmap (bool) : Option to return heatmap. |
|||
|
|||
Returns: |
|||
dict|tuple: if `return_loss` is true, then return losses. \ |
|||
Otherwise, return predicted poses, boxes, image paths \ |
|||
and heatmaps. |
|||
""" |
|||
if return_loss: |
|||
# to visualize img / sam_img here, use OpenCV (imshow) or PIL, e.g.: |
|||
# print(sam_img[0].shape) |
|||
# imshow(sam_img[0].cpu().numpy().transpose(1, 2, 0), wait_time=5000) |
|||
# modified: also pass sam_img to forward_train |
|||
return self.forward_train(img, sam_img, target, target_weight, img_metas, |
|||
**kwargs) |
|||
# modified: also pass sam_img to forward_test |
|||
return self.forward_test( |
|||
img, sam_img, img_metas, return_heatmap=return_heatmap, **kwargs) |
|||
|
|||
# modified: forward_train takes the extra sam_img input |
|||
def forward_train(self, img, sam_img, target, target_weight, img_metas, **kwargs): |
|||
"""Defines the computation performed at every call when training.""" |
|||
# modified: the backbone consumes both img and sam_img |
|||
output = self.backbone(img, sam_img) # B, C, Hp, Wp |
|||
if self.with_neck: |
|||
output = self.neck(output) |
|||
if self.with_keypoint: |
|||
output = self.keypoint_head(output) |
|||
|
|||
# if return loss |
|||
losses = dict() |
|||
if self.with_keypoint: |
|||
keypoint_losses = self.keypoint_head.get_loss( |
|||
output, target, target_weight) |
|||
losses.update(keypoint_losses) |
|||
keypoint_accuracy = self.keypoint_head.get_accuracy( |
|||
output, target, target_weight) |
|||
losses.update(keypoint_accuracy) |
|||
|
|||
return losses |
|||
|
|||
# modified: forward_test takes the extra sam_img input |
|||
def forward_test(self, img, sam_img, img_metas, return_heatmap=False, **kwargs): |
|||
"""Defines the computation performed at every call when testing.""" |
|||
assert img.size(0) == len(img_metas) |
|||
batch_size, _, img_height, img_width = img.shape |
|||
if batch_size > 1: |
|||
assert 'bbox_id' in img_metas[0] |
|||
|
|||
result = {} |
|||
|
|||
# modified: the backbone consumes both img and sam_img |
|||
features = self.backbone(img, sam_img) |
|||
if self.with_neck: |
|||
features = self.neck(features) |
|||
if self.with_keypoint: |
|||
output_heatmap = self.keypoint_head.inference_model( |
|||
features, flip_pairs=None) |
|||
|
|||
if self.test_cfg.get('flip_test', True): |
|||
img_flipped = img.flip(3) |
|||
# modified: flip the SAM input together with img |
|||
sam_img_flipped = sam_img.flip(3) |
|||
features_flipped = self.backbone(img_flipped, sam_img_flipped) |
|||
if self.with_neck: |
|||
features_flipped = self.neck(features_flipped) |
|||
if self.with_keypoint: |
|||
output_flipped_heatmap = self.keypoint_head.inference_model( |
|||
features_flipped, img_metas[0]['flip_pairs']) |
|||
output_heatmap = (output_heatmap + |
|||
output_flipped_heatmap) * 0.5 |
|||
|
|||
if self.with_keypoint: |
|||
keypoint_result = self.keypoint_head.decode( |
|||
img_metas, output_heatmap, img_size=[img_width, img_height]) |
|||
result.update(keypoint_result) |
|||
|
|||
if not return_heatmap: |
|||
output_heatmap = None |
|||
|
|||
result['output_heatmap'] = output_heatmap |
|||
|
|||
return result |
|||
|
|||
# modified: forward_dummy also takes sam_img |
|||
def forward_dummy(self, img, sam_img): |
|||
"""Used for computing network FLOPs. |
|||
|
|||
See ``tools/get_flops.py``. |
|||
|
|||
Args: |
|||
img (torch.Tensor): Input image. |
|||
sam_img (torch.Tensor): Input image for the SAM encoder. |
|||
|
|||
Returns: |
|||
Tensor: Output heatmaps. |
|||
""" |
|||
output = self.backbone(img, sam_img) |
|||
if self.with_neck: |
|||
output = self.neck(output) |
|||
if self.with_keypoint: |
|||
output = self.keypoint_head(output) |
|||
return output |
|||
|
|||
@deprecated_api_warning({'pose_limb_color': 'pose_link_color'}, |
|||
cls_name='TopDown') |
|||
def show_result(self, |
|||
img, |
|||
result, |
|||
skeleton=None, |
|||
kpt_score_thr=0.3, |
|||
bbox_color='green', |
|||
pose_kpt_color=None, |
|||
pose_link_color=None, |
|||
text_color='white', |
|||
radius=4, |
|||
thickness=1, |
|||
font_scale=0.5, |
|||
bbox_thickness=1, |
|||
win_name='', |
|||
show=False, |
|||
show_keypoint_weight=False, |
|||
wait_time=0, |
|||
out_file=None): |
|||
"""Draw `result` over `img`. |
|||
|
|||
Args: |
|||
img (str or Tensor): The image to be displayed. |
|||
result (list[dict]): The results to draw over `img` |
|||
(bbox_result, pose_result). |
|||
skeleton (list[list]): The connection of keypoints. |
|||
skeleton is 0-based indexing. |
|||
kpt_score_thr (float, optional): Minimum score of keypoints |
|||
to be shown. Default: 0.3. |
|||
bbox_color (str or tuple or :obj:`Color`): Color of bbox lines. |
|||
pose_kpt_color (np.array[Nx3]`): Color of N keypoints. |
|||
If None, do not draw keypoints. |
|||
pose_link_color (np.array[Mx3]): Color of M links. |
|||
If None, do not draw links. |
|||
text_color (str or tuple or :obj:`Color`): Color of texts. |
|||
radius (int): Radius of circles. |
|||
thickness (int): Thickness of lines. |
|||
font_scale (float): Font scales of texts. |
|||
win_name (str): The window name. |
|||
show (bool): Whether to show the image. Default: False. |
|||
show_keypoint_weight (bool): Whether to change the transparency |
|||
using the predicted confidence scores of keypoints. |
|||
wait_time (int): Value of waitKey param. |
|||
Default: 0. |
|||
out_file (str or None): The filename to write the image. |
|||
Default: None. |
|||
|
|||
Returns: |
|||
Tensor: Visualized img, only if not `show` or `out_file`. |
|||
""" |
|||
img = mmcv.imread(img) |
|||
img = img.copy() |
|||
|
|||
bbox_result = [] |
|||
bbox_labels = [] |
|||
pose_result = [] |
|||
for res in result: |
|||
if 'bbox' in res: |
|||
bbox_result.append(res['bbox']) |
|||
bbox_labels.append(res.get('label', None)) |
|||
pose_result.append(res['keypoints']) |
|||
|
|||
if bbox_result: |
|||
bboxes = np.vstack(bbox_result) |
|||
# draw bounding boxes |
|||
imshow_bboxes( |
|||
img, |
|||
bboxes, |
|||
labels=bbox_labels, |
|||
colors=bbox_color, |
|||
text_color=text_color, |
|||
thickness=bbox_thickness, |
|||
font_scale=font_scale, |
|||
show=False) |
|||
|
|||
if pose_result: |
|||
imshow_keypoints(img, pose_result, skeleton, kpt_score_thr, |
|||
pose_kpt_color, pose_link_color, radius, |
|||
thickness) |
|||
|
|||
if show: |
|||
imshow(img, win_name, wait_time) |
|||
|
|||
if out_file is not None: |
|||
imwrite(img, out_file) |
|||
|
|||
return img |
@@ -0,0 +1,17 @@ |
|||
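# Sanity check: compare the SAM position embedding from the original SAM checkpoint |
|||
# with backbone.sam_vit.pos_embed in the trained model, to verify the SAM branch stayed frozen. |
|||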
import torch |
|||
import numpy as np |
|||
|
|||
model_1 = torch.load('/home/fhw/code/ViTPose/checkpoints/sam/sam_vit_b_01ec64.pth') |
|||
model_2 = torch.load('/home/fhw/code/ViTPose/work_dirs/ViTSam_base_coco_256x192/best_AP_epoch_1.pth') |
|||
|
|||
param_1 = model_1['image_encoder.pos_embed'].numpy() |
|||
param_2 = model_2['state_dict']['backbone.sam_vit.pos_embed'].numpy() |
|||
|
|||
# for name, param in model_2.items(): |
|||
# print(name) |
|||
|
|||
# print(model_2['state_dict']['backbone.sam_vit.pos_embed']) |
|||
|
|||
is_equal = np.array_equal(param_1, param_2) |
|||
|
|||
print(is_equal) |
@@ -0,0 +1,197 @@ |
|||
# Copyright (c) OpenMMLab. All rights reserved. |
|||
import argparse |
|||
import copy |
|||
import os |
|||
import os.path as osp |
|||
import time |
|||
import warnings |
|||
|
|||
import mmcv |
|||
import torch |
|||
from mmcv import Config, DictAction |
|||
from mmcv.runner import get_dist_info, init_dist, set_random_seed, load_checkpoint |
|||
from mmcv.utils import get_git_hash |
|||
|
|||
from mmpose import __version__ |
|||
from mmpose.apis import init_random_seed, train_model |
|||
from mmpose.datasets import build_dataset |
|||
from mmpose.models import build_posenet |
|||
from mmpose.utils import collect_env, get_root_logger, setup_multi_processes |
|||
import mmcv_custom |
|||
|
|||
def parse_args(): |
|||
parser = argparse.ArgumentParser(description='Train a pose model') |
|||
parser.add_argument('config', help='train config file path') |
|||
parser.add_argument('-c', '--checkpoint', help='checkpoint file', default='/root/autodl-tmp/code/ViTPose/checkpoints/vitpose/vitpose-b.pth') |
|||
parser.add_argument('--work-dir', help='the dir to save logs and models') |
|||
parser.add_argument( |
|||
'--resume-from', help='the checkpoint file to resume from') |
|||
parser.add_argument( |
|||
'--no-validate', |
|||
action='store_true', |
|||
help='whether not to evaluate the checkpoint during training') |
|||
group_gpus = parser.add_mutually_exclusive_group() |
|||
group_gpus.add_argument( |
|||
'--gpus', |
|||
type=int, |
|||
help='(Deprecated, please use --gpu-id) number of gpus to use ' |
|||
'(only applicable to non-distributed training)') |
|||
group_gpus.add_argument( |
|||
'--gpu-ids', |
|||
type=int, |
|||
nargs='+', |
|||
help='(Deprecated, please use --gpu-id) ids of gpus to use ' |
|||
'(only applicable to non-distributed training)') |
|||
group_gpus.add_argument( |
|||
'--gpu-id', |
|||
type=int, |
|||
default=0, |
|||
help='id of gpu to use ' |
|||
'(only applicable to non-distributed training)') |
|||
parser.add_argument('--seed', type=int, default=None, help='random seed') |
|||
parser.add_argument( |
|||
'--deterministic', |
|||
action='store_true', |
|||
help='whether to set deterministic options for CUDNN backend.') |
|||
parser.add_argument( |
|||
'--cfg-options', |
|||
nargs='+', |
|||
action=DictAction, |
|||
default={}, |
|||
help='override some settings in the used config, the key-value pair ' |
|||
'in xxx=yyy format will be merged into config file. For example, ' |
|||
"'--cfg-options model.backbone.depth=18 model.backbone.with_cp=True'") |
|||
parser.add_argument( |
|||
'--launcher', |
|||
choices=['none', 'pytorch', 'slurm', 'mpi'], |
|||
default='none', |
|||
help='job launcher') |
|||
parser.add_argument('--local_rank', type=int, default=0) |
|||
parser.add_argument( |
|||
'--autoscale-lr', |
|||
action='store_true', |
|||
help='automatically scale lr with the number of gpus') |
|||
args = parser.parse_args() |
|||
if 'LOCAL_RANK' not in os.environ: |
|||
os.environ['LOCAL_RANK'] = str(args.local_rank) |
|||
|
|||
return args |
|||
|
|||
|
|||
def main(): |
|||
args = parse_args() |
|||
|
|||
cfg = Config.fromfile(args.config) |
|||
|
|||
if args.cfg_options is not None: |
|||
cfg.merge_from_dict(args.cfg_options) |
|||
|
|||
# set multi-process settings |
|||
setup_multi_processes(cfg) |
|||
|
|||
# set cudnn_benchmark |
|||
if cfg.get('cudnn_benchmark', False): |
|||
torch.backends.cudnn.benchmark = True |
|||
|
|||
# work_dir is determined in this priority: CLI > segment in file > filename |
|||
if args.work_dir is not None: |
|||
# update configs according to CLI args if args.work_dir is not None |
|||
cfg.work_dir = args.work_dir |
|||
elif cfg.get('work_dir', None) is None: |
|||
# use config filename as default work_dir if cfg.work_dir is None |
|||
cfg.work_dir = osp.join('./work_dirs', |
|||
osp.splitext(osp.basename(args.config))[0]) |
|||
if args.resume_from is not None: |
|||
cfg.resume_from = args.resume_from |
|||
if args.gpus is not None: |
|||
cfg.gpu_ids = range(1) |
|||
warnings.warn('`--gpus` is deprecated because we only support ' |
|||
'single GPU mode in non-distributed training. ' |
|||
'Use `gpus=1` now.') |
|||
if args.gpu_ids is not None: |
|||
cfg.gpu_ids = args.gpu_ids[0:1] |
|||
warnings.warn('`--gpu-ids` is deprecated, please use `--gpu-id`. ' |
|||
'Because we only support single GPU mode in ' |
|||
'non-distributed training. Use the first GPU ' |
|||
'in `gpu_ids` now.') |
|||
if args.gpus is None and args.gpu_ids is None: |
|||
cfg.gpu_ids = [args.gpu_id] |
|||
|
|||
if args.autoscale_lr: |
|||
# apply the linear scaling rule (https://arxiv.org/abs/1706.02677) |
|||
cfg.optimizer['lr'] = cfg.optimizer['lr'] * len(cfg.gpu_ids) / 8 |
|||
|
|||
# init distributed env first, since logger depends on the dist info. |
|||
if args.launcher == 'none': |
|||
distributed = False |
|||
if len(cfg.gpu_ids) > 1: |
|||
warnings.warn( |
|||
f'We treat {cfg.gpu_ids} as gpu-ids, and reset to ' |
|||
f'{cfg.gpu_ids[0:1]} as gpu-ids to avoid potential error in ' |
|||
'non-distribute training time.') |
|||
cfg.gpu_ids = cfg.gpu_ids[0:1] |
|||
else: |
|||
distributed = True |
|||
init_dist(args.launcher, **cfg.dist_params) |
|||
# re-set gpu_ids with distributed training mode |
|||
_, world_size = get_dist_info() |
|||
cfg.gpu_ids = range(world_size) |
|||
|
|||
# create work_dir |
|||
mmcv.mkdir_or_exist(osp.abspath(cfg.work_dir)) |
|||
# init the logger before other steps |
|||
timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime()) |
|||
log_file = osp.join(cfg.work_dir, f'{timestamp}.log') |
|||
logger = get_root_logger(log_file=log_file, log_level=cfg.log_level) |
|||
|
|||
# init the meta dict to record some important information such as |
|||
# environment info and seed, which will be logged |
|||
meta = dict() |
|||
# log env info |
|||
env_info_dict = collect_env() |
|||
env_info = '\n'.join([(f'{k}: {v}') for k, v in env_info_dict.items()]) |
|||
dash_line = '-' * 60 + '\n' |
|||
logger.info('Environment info:\n' + dash_line + env_info + '\n' + |
|||
dash_line) |
|||
meta['env_info'] = env_info |
|||
|
|||
# log some basic info |
|||
logger.info(f'Distributed training: {distributed}') |
|||
logger.info(f'Config:\n{cfg.pretty_text}') |
|||
|
|||
# set random seeds |
|||
seed = init_random_seed(args.seed) |
|||
logger.info(f'Set random seed to {seed}, ' |
|||
f'deterministic: {args.deterministic}') |
|||
set_random_seed(seed, deterministic=args.deterministic) |
|||
cfg.seed = seed |
|||
meta['seed'] = seed |
|||
|
|||
model = build_posenet(cfg.model) |
|||
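# initialise the pose model from the ViTPose checkpoint given by -c/--checkpoint |
|||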
load_checkpoint(model, args.checkpoint, map_location='cpu') |
|||
datasets = [build_dataset(cfg.data.train)] |
|||
|
|||
if len(cfg.workflow) == 2: |
|||
val_dataset = copy.deepcopy(cfg.data.val) |
|||
val_dataset.pipeline = cfg.data.train.pipeline |
|||
datasets.append(build_dataset(val_dataset)) |
|||
|
|||
if cfg.checkpoint_config is not None: |
|||
# save mmpose version, config file content |
|||
# checkpoints as meta data |
|||
cfg.checkpoint_config.meta = dict( |
|||
mmpose_version=__version__ + get_git_hash(digits=7), |
|||
config=cfg.pretty_text, |
|||
) |
|||
train_model( |
|||
model, |
|||
datasets, |
|||
cfg, |
|||
distributed=distributed, |
|||
validate=(not args.no_validate), |
|||
timestamp=timestamp, |
|||
meta=meta) |
|||
|
|||
|
|||
if __name__ == '__main__': |
|||
main() |
File diff suppressed because it is too large
@@ -0,0 +1 @@ |
|||
Environment setup issue: the installed setuptools version is too new, so the installation method this project relies on has been deprecated; it is recommended to reinstall a setuptools version lower than 60. |
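For example (one possible fix, assuming a pip-managed environment): pip install "setuptools<60" |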