diff --git a/mmpose/models/backbones/sam_vit/image_encoder.py b/mmpose/models/backbones/sam_vit/image_encoder.py
index 10e4488..d2dd8d6 100644
--- a/mmpose/models/backbones/sam_vit/image_encoder.py
+++ b/mmpose/models/backbones/sam_vit/image_encoder.py
@@ -42,14 +42,14 @@ class ImageEncoderViT(nn.Module):
         self.img_size = img_size
         self.patch_size = patch_size
         self.embed_dim = embed_dim
-        self.num_patches = (img_size // patch_size) ** 2
-        # self.num_patches = (img_size[0] // patch_size) * (img_size[1] // patch_size)
+        # self.num_patches = (img_size // patch_size) ** 2
+        self.num_patches = (img_size[0] // patch_size) * (img_size[1] // patch_size)
         self.pos_embed: Optional[nn.Parameter] = None
         self.checkpoint = checkpoint
         if use_abs_pos:
             # Initialize absolute positional embedding with pretrain image size.
             self.pos_embed = nn.Parameter(
-                torch.zeros(1, img_size // patch_size, img_size // patch_size, embed_dim)
+                torch.zeros(1, img_size[0] // patch_size, img_size[1] // patch_size, embed_dim)
             )
 
         # ------------ Model parameters ------------
@@ -72,13 +72,13 @@ class ImageEncoderViT(nn.Module):
                 act_layer = act_layer,
                 use_rel_pos = use_rel_pos,
                 window_size = window_size if i not in global_attn_indexes else 0,
-                input_size = (img_size // patch_size, img_size // patch_size),
+                input_size = (img_size[0] // patch_size, img_size[1] // patch_size),
             )
             self.blocks.append(block)
 
-        self.load_pretrained()
+        self.load_pretrained(global_attn_indexes)
 
-    def load_pretrained(self):
+    def load_pretrained(self, global_attn_indexes=()):
         if self.checkpoint is not None:
             print('Loading SAM pretrained weight from : {}'.format(self.checkpoint))
             # checkpoint state dict
@@ -86,6 +86,9 @@ class ImageEncoderViT(nn.Module):
             # model state dict
             model_state_dict = self.state_dict()
             encoder_state_dict = {}
+
+            patterns_h = ["blocks." + str(i) + ".attn.rel_pos_h" for i in global_attn_indexes]
+            patterns_w = ["blocks." + str(i) + ".attn.rel_pos_w" for i in global_attn_indexes]
             # check
             for k in list(checkpoint_state_dict.keys()):
                 if "image_encoder" in k and k[14:] in model_state_dict:
@@ -93,12 +96,16 @@ class ImageEncoderViT(nn.Module):
                     shape_checkpoint = tuple(checkpoint_state_dict[k].shape)
                     if shape_model == shape_checkpoint or "pos_embed" in k:
                         encoder_state_dict[k[14:]] = checkpoint_state_dict[k]
+                    elif k[14:] in patterns_h or k[14:] in patterns_w:
+                        encoder_state_dict[k[14:]] = checkpoint_state_dict[k]
                     else:
                         print("Shape unmatch: ", k)
 
             # interpolate position embedding
-            # interpolate_pos_embed(self, encoder_state_dict, ((self.img_size[0] // self.patch_size), (self.img_size[1] // self.patch_size)))
-            interpolate_pos_embed(self, encoder_state_dict,)
+            interpolate_pos_embed(self, encoder_state_dict, ((self.img_size[0] // self.patch_size), (self.img_size[1] // self.patch_size)))
+            # interpolate_pos_embed(self, encoder_state_dict,)
+            # interpolate relative position embedding
+            interpolate_rel_pos_embed(encoder_state_dict, (self.img_size[0] // self.patch_size, self.img_size[1] // self.patch_size), global_attn_indexes)
 
             # load the weight
             self.load_state_dict(encoder_state_dict, strict=False)
@@ -369,7 +376,7 @@ def add_decomposed_rel_pos(attn : torch.Tensor,
     return attn
 
 
-def interpolate_pos_embed(model, checkpoint_model):
+def interpolate_pos_embed(model, checkpoint_model, new_size):
     if 'pos_embed' in checkpoint_model:
         # Pos embed from checkpoint
         pos_embed_checkpoint = checkpoint_model['pos_embed']
@@ -386,25 +393,59 @@ def interpolate_pos_embed(model, checkpoint_model):
 
         # height (== width) for the checkpoint position embedding
         orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
-        new_size = int(num_patches ** 0.5)
+        # new_size = int(num_patches ** 0.5)
         # height (== width) for the new position embedding
         # class_token and dist_token are kept unchanged
-        if orig_size != new_size:
-            print("- Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size))
+        if orig_size != new_size[0] or orig_size != new_size[1]:
+            print("- Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size[0], new_size[1]))
             extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
             # only the position tokens are interpolated
             pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
             pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
             pos_tokens = torch.nn.functional.interpolate(pos_tokens,
-                                                          size=(new_size,new_size),
+                                                          # size=(new_size,new_size),
+                                                          size=new_size,
                                                           mode='bicubic',
                                                           align_corners=False)
             pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
             new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
-            new_pos_embed = new_pos_embed.reshape(-1, int(orig_num_postions ** 0.5), int(orig_num_postions ** 0.5), embedding_size)
+            # new_pos_embed = new_pos_embed.reshape(-1, int(orig_num_postions ** 0.5), int(orig_num_postions ** 0.5), embedding_size)
+            new_pos_embed = new_pos_embed.reshape(-1, new_size[0], new_size[1], embedding_size)
 
             checkpoint_model['pos_embed'] = new_pos_embed
 
 
+def interpolate_rel_pos_embed(checkpoint_model, image_size, global_attn_indexes):
+    for i in global_attn_indexes:
+        if f'blocks.{i}.attn.rel_pos_h' in checkpoint_model:
+            # Pos embed from checkpoint
+            rel_pos_h_checkpoint = checkpoint_model[f'blocks.{i}.attn.rel_pos_h']
+            rel_pos_w_checkpoint = checkpoint_model[f'blocks.{i}.attn.rel_pos_w']
+            embedding_size = rel_pos_h_checkpoint.shape[-1]
+
+            orig_size = (rel_pos_h_checkpoint.shape[-2], rel_pos_w_checkpoint.shape[-2])
+
+            new_size = (2*image_size[0]-1, 2*image_size[1]-1)
+
+            # height (== width) for the new position embedding
+            # class_token and dist_token are kept unchanged
+            if orig_size != new_size:
+                print("- Relative Position interpolate from %dx%d to %dx%d" % (orig_size[0], orig_size[1], new_size[0], new_size[1]))
+                # only the position tokens are interpolated
+                rel_pos_h_checkpoint = rel_pos_h_checkpoint.reshape(orig_size[0], embedding_size).unsqueeze(0).unsqueeze(0)
+                rel_pos_w_checkpoint = rel_pos_w_checkpoint.reshape(orig_size[1], embedding_size).unsqueeze(0).unsqueeze(0)
+                rel_pos_h_checkpoint = torch.nn.functional.interpolate(rel_pos_h_checkpoint,
+                                                                       size=(new_size[0], embedding_size),
+                                                                       mode='bicubic',
+                                                                       align_corners=False)
+                rel_pos_w_checkpoint = torch.nn.functional.interpolate(rel_pos_w_checkpoint,
+                                                                       size=(new_size[1], embedding_size),
+                                                                       mode='bicubic',
+                                                                       align_corners=False)
+                rel_pos_h_checkpoint = rel_pos_h_checkpoint.squeeze(0).squeeze(0).reshape(new_size[0], embedding_size)
+                rel_pos_w_checkpoint = rel_pos_w_checkpoint.squeeze(0).squeeze(0).reshape(new_size[1], embedding_size)
+                checkpoint_model[f'blocks.{i}.attn.rel_pos_h'] = rel_pos_h_checkpoint
+                checkpoint_model[f'blocks.{i}.attn.rel_pos_w'] = rel_pos_w_checkpoint
+
 # ------------------------ Model Functions ------------------------
 def build_vit_sam(model_name="vit_h", img_size=1024, patch_size=16, img_dim=3, checkpoint=None):
@@ -454,24 +495,24 @@ if __name__ == '__main__':
     from thop import profile
 
     # Prepare an image as the input
-    bs, c, h, w = 2, 3, 1024, 1024
+    bs, c, h, w = 2, 3, 256, 256
     x = torch.randn(bs, c, h, w)
     patch_size = 16
     device = torch.device('cuda:1' if torch.cuda.is_available() else 'cpu')
 
     # Build model
-    model = build_vit_sam(model_name='vit_b', checkpoint="/home/fhw/code/ViTPose/checkpoints/sam/sam_vit_b_01ec64.pth")
+    model = build_vit_sam(model_name='vit_b', checkpoint="/root/autodl-tmp/code/ViTPose/checkpoints/sam/sam_vit_b_01ec64.pth", img_size=(256, 256))
     if torch.cuda.is_available():
         x = x.to(device)
         model = model.to(device)
 
-    # Inference
-    outputs = model(x)
-    print(outputs.shape)
+    # # Inference
+    # outputs = model(x)
+    # print(outputs.shape)
 
-    # Compute FLOPs & Params
-    print('==============================')
-    model.eval()
-    flops, params = profile(model, inputs=(x, ), verbose=False)
-    print('GFLOPs : {:.2f}'.format(flops / 1e9 * 2))
-    print('Params : {:.2f} M'.format(params / 1e6))
+    # # Compute FLOPs & Params
+    # print('==============================')
+    # model.eval()
+    # flops, params = profile(model, inputs=(x, ), verbose=False)
+    # print('GFLOPs : {:.2f}'.format(flops / 1e9 * 2))
+    # print('Params : {:.2f} M'.format(params / 1e6))
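
Illustrative sketch (not part of the patch): the hunks above resize SAM's pretraining position embeddings so the ViT-B encoder can take a rectangular input such as 256x256. The standalone snippet below shows the same idea in isolation. The shapes (a 64x64 pretraining grid for 1024-px input at patch 16, embedding dim 768, head dim 64) are assumptions, and the relative-position tables are resized here with 1-D linear interpolation along their length, a common alternative to the patch's bicubic interpolation over a 2-D view of the table.

# Standalone sketch under the assumptions stated above; not the patch's code path.
import torch
import torch.nn.functional as F

def resize_abs_pos_embed(pos_embed: torch.Tensor, new_hw: tuple) -> torch.Tensor:
    # pos_embed: (1, H, W, C) grid of absolute position embeddings.
    pe = pos_embed.permute(0, 3, 1, 2)                                    # (1, C, H, W)
    pe = F.interpolate(pe, size=new_hw, mode='bicubic', align_corners=False)
    return pe.permute(0, 2, 3, 1)                                         # (1, H', W', C)

def resize_rel_pos(rel_pos: torch.Tensor, new_len: int) -> torch.Tensor:
    # rel_pos: (2*S - 1, head_dim) relative-position table; resized along its length only.
    rp = rel_pos.t().unsqueeze(0)                                         # (1, head_dim, 2*S - 1)
    rp = F.interpolate(rp, size=new_len, mode='linear', align_corners=False)
    return rp.squeeze(0).t()                                              # (new_len, head_dim)

pos_embed = torch.zeros(1, 64, 64, 768)    # assumed 1024 / 16 pretraining grid, ViT-B dim
rel_pos_h = torch.zeros(2 * 64 - 1, 64)    # assumed per-block rel_pos_h table, head dim 64
print(resize_abs_pos_embed(pos_embed, (16, 16)).shape)   # torch.Size([1, 16, 16, 768])
print(resize_rel_pos(rel_pos_h, 2 * 16 - 1).shape)       # torch.Size([31, 64])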