
Interpolate SAM's relative positional encodings

main · fhw, 1 year ago · parent commit 0eb96d2c47
1 changed file: mmpose/models/backbones/sam_vit/image_encoder.py (91)

@@ -42,14 +42,14 @@ class ImageEncoderViT(nn.Module):
         self.img_size = img_size
         self.patch_size = patch_size
         self.embed_dim = embed_dim
-        self.num_patches = (img_size // patch_size) ** 2
-        # self.num_patches = (img_size[0] // patch_size) * (img_size[1] // patch_size)
+        # self.num_patches = (img_size // patch_size) ** 2
+        self.num_patches = (img_size[0] // patch_size) * (img_size[1] // patch_size)
         self.pos_embed: Optional[nn.Parameter] = None
         self.checkpoint = checkpoint
 
         if use_abs_pos:
             # Initialize absolute positional embedding with pretrain image size.
             self.pos_embed = nn.Parameter(
-                torch.zeros(1, img_size // patch_size, img_size // patch_size, embed_dim)
+                torch.zeros(1, img_size[0] // patch_size, img_size[1] // patch_size, embed_dim)
             )
         # ------------ Model parameters ------------
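Note: after this hunk, img_size is treated as an (H, W) tuple rather than a single int, so the patch grid (and the zero-initialized pos_embed) can be rectangular. A quick check of the new shape arithmetic, with hypothetical values:

    img_size, patch_size, embed_dim = (256, 192), 16, 768  # hypothetical rectangular input
    grid_h, grid_w = img_size[0] // patch_size, img_size[1] // patch_size  # (16, 12)
    num_patches = grid_h * grid_w  # 192
    # pos_embed is zero-initialized as (1, grid_h, grid_w, embed_dim) = (1, 16, 12, 768)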
@@ -72,13 +72,13 @@ class ImageEncoderViT(nn.Module):
                 act_layer = act_layer,
                 use_rel_pos = use_rel_pos,
                 window_size = window_size if i not in global_attn_indexes else 0,
-                input_size = (img_size // patch_size, img_size // patch_size),
+                input_size = (img_size[0] // patch_size, img_size[1] // patch_size),
             )
             self.blocks.append(block)
 
-        self.load_pretrained()
+        self.load_pretrained(global_attn_indexes)
 
-    def load_pretrained(self):
+    def load_pretrained(self, global_attn_indexes=()):
         if self.checkpoint is not None:
             print('Loading SAM pretrained weight from : {}'.format(self.checkpoint))
             # checkpoint state dict
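Note: load_pretrained now receives global_attn_indexes, so the loader knows which blocks run global (non-windowed) attention; only those blocks' relative-position tables depend on the full input resolution and need the special handling added below.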
@@ -86,6 +86,9 @@ class ImageEncoderViT(nn.Module):
             # model state dict
             model_state_dict = self.state_dict()
             encoder_state_dict = {}
+            patterns_h = ["blocks." + str(i) + ".attn.rel_pos_h" for i in global_attn_indexes]
+            patterns_w = ["blocks." + str(i) + ".attn.rel_pos_w" for i in global_attn_indexes]
+
             # check
             for k in list(checkpoint_state_dict.keys()):
                 if "image_encoder" in k and k[14:] in model_state_dict:
@@ -93,12 +96,16 @@ class ImageEncoderViT(nn.Module):
                     shape_checkpoint = tuple(checkpoint_state_dict[k].shape)
                     if shape_model == shape_checkpoint or "pos_embed" in k:
                         encoder_state_dict[k[14:]] = checkpoint_state_dict[k]
+                    elif k[14:] in patterns_h or k[14:] in patterns_w:
+                        encoder_state_dict[k[14:]] = checkpoint_state_dict[k]
                     else:
                         print("Shape unmatch: ", k)
 
             # interpolate position embedding
-            # interpolate_pos_embed(self, encoder_state_dict, ((self.img_size[0] // self.patch_size), (self.img_size[1] // self.patch_size)))
-            interpolate_pos_embed(self, encoder_state_dict,)
+            interpolate_pos_embed(self, encoder_state_dict, ((self.img_size[0] // self.patch_size), (self.img_size[1] // self.patch_size)))
+            # interpolate_pos_embed(self, encoder_state_dict,)
+            # interpolate relative position embedding
+            interpolate_rel_pos_embed(encoder_state_dict, (self.img_size[0] // self.patch_size, self.img_size[1] // self.patch_size), global_attn_indexes)
 
             # load the weight
             self.load_state_dict(encoder_state_dict, strict=False)
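Note: k[14:] strips the leading "image_encoder." prefix (14 characters) from SAM checkpoint keys. The new patterns_h/patterns_w whitelists let the global-attention blocks' rel_pos_h/rel_pos_w tensors past the shape check even though their lengths no longer match the model, since interpolate_rel_pos_embed resizes them right afterwards. A minimal sketch of the filtering, on a hypothetical one-entry checkpoint:

    import torch

    checkpoint_state_dict = {"image_encoder.blocks.2.attn.rel_pos_h": torch.zeros(127, 64)}  # hypothetical entry
    global_attn_indexes = (2, 5, 8, 11)  # SAM ViT-B defaults
    patterns_h = ["blocks." + str(i) + ".attn.rel_pos_h" for i in global_attn_indexes]
    for k in checkpoint_state_dict:
        # len("image_encoder.") == 14, so k[14:] is the in-model parameter name
        print(k[14:], k[14:] in patterns_h)  # -> blocks.2.attn.rel_pos_h True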
@@ -369,7 +376,7 @@ def add_decomposed_rel_pos(attn : torch.Tensor,
     return attn
 
 
-def interpolate_pos_embed(model, checkpoint_model):
+def interpolate_pos_embed(model, checkpoint_model, new_size):
     if 'pos_embed' in checkpoint_model:
         # Pos embed from checkpoint
         pos_embed_checkpoint = checkpoint_model['pos_embed']
@@ -386,25 +393,59 @@ def interpolate_pos_embed(model, checkpoint_model):
         # height (== width) for the checkpoint position embedding
         orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
-        new_size = int(num_patches ** 0.5)
+        # new_size = int(num_patches ** 0.5)
         # height (== width) for the new position embedding
         # class_token and dist_token are kept unchanged
-        if orig_size != new_size:
-            print("- Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size))
+        if orig_size != new_size[0] or orig_size != new_size[1]:
+            print("- Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size[0], new_size[1]))
             extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
             # only the position tokens are interpolated
             pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
             pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
             pos_tokens = torch.nn.functional.interpolate(pos_tokens,
-                                                         size=(new_size,new_size),
+                                                         # size=(new_size,new_size),
+                                                         size=new_size,
                                                          mode='bicubic',
                                                          align_corners=False)
             pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
             new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
-            new_pos_embed = new_pos_embed.reshape(-1, int(orig_num_postions ** 0.5), int(orig_num_postions ** 0.5), embedding_size)
+            # new_pos_embed = new_pos_embed.reshape(-1, int(orig_num_postions ** 0.5), int(orig_num_postions ** 0.5), embedding_size)
+            new_pos_embed = new_pos_embed.reshape(-1, new_size[0], new_size[1], embedding_size)
             checkpoint_model['pos_embed'] = new_pos_embed
+
+
+def interpolate_rel_pos_embed(checkpoint_model, image_size, global_attn_indexes):
+    for i in global_attn_indexes:
+        if f'blocks.{i}.attn.rel_pos_h' in checkpoint_model:
+            # Pos embed from checkpoint
+            rel_pos_h_checkpoint = checkpoint_model[f'blocks.{i}.attn.rel_pos_h']
+            rel_pos_w_checkpoint = checkpoint_model[f'blocks.{i}.attn.rel_pos_w']
+            embedding_size = rel_pos_h_checkpoint.shape[-1]
+            orig_size = (rel_pos_h_checkpoint.shape[-2], rel_pos_w_checkpoint.shape[-2])
+            new_size = (2*image_size[0]-1, 2*image_size[1]-1)
+            # height (== width) for the new position embedding
+            # class_token and dist_token are kept unchanged
+            if orig_size != new_size:
+                print("- Relative Position interpolate from %dx%d to %dx%d" % (orig_size[0], orig_size[1], new_size[0], new_size[1]))
+                # only the position tokens are interpolated
+                rel_pos_h_checkpoint = rel_pos_h_checkpoint.reshape(orig_size[0], embedding_size).unsqueeze(0).unsqueeze(0)
+                rel_pos_w_checkpoint = rel_pos_w_checkpoint.reshape(orig_size[1], embedding_size).unsqueeze(0).unsqueeze(0)
+                rel_pos_h_checkpoint = torch.nn.functional.interpolate(rel_pos_h_checkpoint,
+                                                                       size=(new_size[0], embedding_size),
+                                                                       mode='bicubic',
+                                                                       align_corners=False)
+                rel_pos_w_checkpoint = torch.nn.functional.interpolate(rel_pos_w_checkpoint,
+                                                                       size=(new_size[1], embedding_size),
+                                                                       mode='bicubic',
+                                                                       align_corners=False)
+                rel_pos_h_checkpoint = rel_pos_h_checkpoint.squeeze(0).squeeze(0).reshape(new_size[0], embedding_size)
+                rel_pos_w_checkpoint = rel_pos_w_checkpoint.squeeze(0).squeeze(0).reshape(new_size[1], embedding_size)
+                checkpoint_model[f'blocks.{i}.attn.rel_pos_h'] = rel_pos_h_checkpoint
+                checkpoint_model[f'blocks.{i}.attn.rel_pos_w'] = rel_pos_w_checkpoint
 
 
 # ------------------------ Model Functions ------------------------
 def build_vit_sam(model_name="vit_h", img_size=1024, patch_size=16, img_dim=3, checkpoint=None):
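Note on new_size = (2*image_size[0]-1, 2*image_size[1]-1): SAM's decomposed relative attention looks up one table entry per relative offset along each axis, and a grid of length L has 2L-1 possible offsets. ViT-B was pretrained at 1024/16, i.e. a 64x64 patch grid, so the checkpoint tables have 2*64-1 = 127 rows; a 256x256 input gives a 16x16 grid and thus 2*16-1 = 31 rows. A standalone sketch of the same resize the new function performs (dummy shapes, assuming ViT-B's head dim of 64):

    import torch

    rel_pos = torch.randn(127, 64)   # pretrained table: 127 offsets, head dim 64
    new_len = 2 * (256 // 16) - 1    # 31 offsets for a 16x16 grid
    resized = torch.nn.functional.interpolate(
        rel_pos[None, None],         # -> (1, 1, 127, 64) so 2D bicubic applies
        size=(new_len, rel_pos.shape[-1]),
        mode='bicubic', align_corners=False)[0, 0]
    print(resized.shape)             # torch.Size([31, 64])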
@@ -454,24 +495,24 @@ if __name__ == '__main__':
     from thop import profile
 
     # Prepare an image as the input
-    bs, c, h, w = 2, 3, 1024, 1024
+    bs, c, h, w = 2, 3, 256, 256
     x = torch.randn(bs, c, h, w)
     patch_size = 16
     device = torch.device('cuda:1' if torch.cuda.is_available() else 'cpu')
 
     # Build model
-    model = build_vit_sam(model_name='vit_b', checkpoint="/home/fhw/code/ViTPose/checkpoints/sam/sam_vit_b_01ec64.pth")
+    model = build_vit_sam(model_name='vit_b', checkpoint="/root/autodl-tmp/code/ViTPose/checkpoints/sam/sam_vit_b_01ec64.pth", img_size=(256, 256))
 
     if torch.cuda.is_available():
         x = x.to(device)
         model = model.to(device)
 
-    # Inference
-    outputs = model(x)
-    print(outputs.shape)
+    # # Inference
+    # outputs = model(x)
+    # print(outputs.shape)
 
-    # Compute FLOPs & Params
-    print('==============================')
-    model.eval()
-    flops, params = profile(model, inputs=(x, ), verbose=False)
-    print('GFLOPs : {:.2f}'.format(flops / 1e9 * 2))
-    print('Params : {:.2f} M'.format(params / 1e6))
+    # # Compute FLOPs & Params
+    # print('==============================')
+    # model.eval()
+    # flops, params = profile(model, inputs=(x, ), verbose=False)
+    # print('GFLOPs : {:.2f}'.format(flops / 1e9 * 2))
+    # print('Params : {:.2f} M'.format(params / 1e6))
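Note: the smoke test now builds the encoder at 256x256 (the inference and FLOPs sections are commented out in this commit). A minimal usage sketch under the new tuple-based img_size; the checkpoint path is machine-specific, so substitute your own copy of sam_vit_b_01ec64.pth:

    import torch

    x = torch.randn(1, 3, 256, 256)
    model = build_vit_sam(model_name='vit_b', img_size=(256, 256),
                          checkpoint=None)  # or a local path to sam_vit_b_01ec64.pth
    # outputs = model(x)  # forward pass, assuming the interpolated embeddings load cleanly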
