@@ -42,14 +42,14 @@ class ImageEncoderViT(nn.Module):
         self.img_size = img_size
         self.patch_size = patch_size
         self.embed_dim = embed_dim
-        self.num_patches = (img_size // patch_size) ** 2
-        # self.num_patches = (img_size[0] // patch_size) * (img_size[1] // patch_size)
+        # self.num_patches = (img_size // patch_size) ** 2
+        self.num_patches = (img_size[0] // patch_size) * (img_size[1] // patch_size)
         self.pos_embed: Optional[nn.Parameter] = None
         self.checkpoint = checkpoint
         if use_abs_pos:
             # Initialize absolute positional embedding with pretrain image size.
             self.pos_embed = nn.Parameter(
-                torch.zeros(1, img_size // patch_size, img_size // patch_size, embed_dim)
+                torch.zeros(1, img_size[0] // patch_size, img_size[1] // patch_size, embed_dim)
             )

         # ------------ Model parameters ------------
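The hunk above switches img_size from a single int to an (H, W) tuple, so the patch grid and the absolute positional embedding become rectangular. A minimal sketch of the resulting shapes, using an assumed 256x192 input, patch size 16, and embed_dim 768 that are purely illustrative and not taken from the patch:

# Illustrative only: assumed img_size, patch_size, embed_dim, not values from this diff.
img_size, patch_size, embed_dim = (256, 192), 16, 768
grid_h, grid_w = img_size[0] // patch_size, img_size[1] // patch_size   # 16, 12
num_patches = grid_h * grid_w                                           # 192
pos_embed_shape = (1, grid_h, grid_w, embed_dim)                        # (1, 16, 12, 768)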
@@ -72,13 +72,13 @@ class ImageEncoderViT(nn.Module):
                 act_layer = act_layer,
                 use_rel_pos = use_rel_pos,
                 window_size = window_size if i not in global_attn_indexes else 0,
-                input_size = (img_size // patch_size, img_size // patch_size),
+                input_size = (img_size[0] // patch_size, img_size[1] // patch_size),
             )
             self.blocks.append(block)

-        self.load_pretrained()
+        self.load_pretrained(global_attn_indexes)

-    def load_pretrained(self):
+    def load_pretrained(self, global_attn_indexes=()):
         if self.checkpoint is not None:
             print('Loading SAM pretrained weight from : {}'.format(self.checkpoint))
             # checkpoint state dict
@@ -86,6 +86,9 @@ class ImageEncoderViT(nn.Module):
             # model state dict
             model_state_dict = self.state_dict()
             encoder_state_dict = {}
+
+            patterns_h = ["blocks." + str(i) + ".attn.rel_pos_h" for i in global_attn_indexes]
+            patterns_w = ["blocks." + str(i) + ".attn.rel_pos_w" for i in global_attn_indexes]
             # check
             for k in list(checkpoint_state_dict.keys()):
                 if "image_encoder" in k and k[14:] in model_state_dict:
@@ -93,12 +96,16 @@ class ImageEncoderViT(nn.Module):
                     shape_checkpoint = tuple(checkpoint_state_dict[k].shape)
                     if shape_model == shape_checkpoint or "pos_embed" in k:
                         encoder_state_dict[k[14:]] = checkpoint_state_dict[k]
+                    elif k[14:] in patterns_h or k[14:] in patterns_w:
+                        encoder_state_dict[k[14:]] = checkpoint_state_dict[k]
                     else:
                         print("Shape unmatch: ", k)

             # interpolate position embedding
-            # interpolate_pos_embed(self, encoder_state_dict, ((self.img_size[0] // self.patch_size), (self.img_size[1] // self.patch_size)))
-            interpolate_pos_embed(self, encoder_state_dict,)
+            interpolate_pos_embed(self, encoder_state_dict, ((self.img_size[0] // self.patch_size), (self.img_size[1] // self.patch_size)))
+            # interpolate_pos_embed(self, encoder_state_dict,)
+            # interpolate relative position embedding
+            interpolate_rel_pos_embed(encoder_state_dict, (self.img_size[0] // self.patch_size, self.img_size[1] // self.patch_size), global_attn_indexes)

             # load the weight
             self.load_state_dict(encoder_state_dict, strict=False)
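The loading logic above keeps a checkpoint tensor when its key starts with the 14-character prefix "image_encoder." (which the k[14:] slice strips) and either the shapes already match or the key is one that gets resized afterwards: pos_embed and the rel_pos tables of the global-attention blocks. A rough standalone sketch of that filtering rule, with toy dictionaries and shapes standing in for the real SAM checkpoint and encoder state dicts (block index 2 and all sizes are assumptions for illustration):

import torch

# Toy stand-ins, not the real ViT-B shapes.
checkpoint_state_dict = {
    "image_encoder.patch_embed.proj.weight": torch.zeros(8, 3, 16, 16),
    "image_encoder.pos_embed": torch.zeros(1, 64, 64, 8),        # resized later
    "image_encoder.blocks.2.attn.rel_pos_h": torch.zeros(127, 8),  # resized later
    "mask_decoder.some.weight": torch.zeros(4, 4),                 # not part of the encoder
}
model_state_dict = {
    "patch_embed.proj.weight": torch.zeros(8, 3, 16, 16),
    "pos_embed": torch.zeros(1, 16, 12, 8),
    "blocks.2.attn.rel_pos_h": torch.zeros(31, 8),
}
patterns = {"blocks.2.attn.rel_pos_h", "blocks.2.attn.rel_pos_w"}

encoder_state_dict = {}
for k in list(checkpoint_state_dict.keys()):
    if "image_encoder" in k and k[14:] in model_state_dict:
        name = k[14:]  # drop the "image_encoder." prefix (14 characters)
        if checkpoint_state_dict[k].shape == model_state_dict[name].shape or "pos_embed" in name:
            encoder_state_dict[name] = checkpoint_state_dict[k]  # exact match, or interpolated later
        elif name in patterns:
            encoder_state_dict[name] = checkpoint_state_dict[k]  # rel_pos table, interpolated later
        else:
            print("Shape mismatch:", k)

assert set(encoder_state_dict) == {"patch_embed.proj.weight", "pos_embed", "blocks.2.attn.rel_pos_h"}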
@@ -369,7 +376,7 @@ def add_decomposed_rel_pos(attn : torch.Tensor,

     return attn

-def interpolate_pos_embed(model, checkpoint_model):
+def interpolate_pos_embed(model, checkpoint_model, new_size):
     if 'pos_embed' in checkpoint_model:
         # Pos embed from checkpoint
         pos_embed_checkpoint = checkpoint_model['pos_embed']
@@ -386,25 +393,59 @@ def interpolate_pos_embed(model, checkpoint_model):

         # height (== width) for the checkpoint position embedding
         orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
-        new_size = int(num_patches ** 0.5)
+        # new_size = int(num_patches ** 0.5)

         # height (== width) for the new position embedding
         # class_token and dist_token are kept unchanged
-        if orig_size != new_size:
-            print("- Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size))
+        if orig_size != new_size[0] or orig_size != new_size[1]:
+            print("- Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size[0], new_size[1]))
             extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
             # only the position tokens are interpolated
             pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
             pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
             pos_tokens = torch.nn.functional.interpolate(pos_tokens,
-                                                         size=(new_size,new_size),
+                                                         # size=(new_size,new_size),
+                                                         size=new_size,
                                                          mode='bicubic',
                                                          align_corners=False)
             pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
             new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
-            new_pos_embed = new_pos_embed.reshape(-1, int(orig_num_postions ** 0.5), int(orig_num_postions ** 0.5), embedding_size)
+            # new_pos_embed = new_pos_embed.reshape(-1, int(orig_num_postions ** 0.5), int(orig_num_postions ** 0.5), embedding_size)
+            new_pos_embed = new_pos_embed.reshape(-1, new_size[0], new_size[1], embedding_size)
             checkpoint_model['pos_embed'] = new_pos_embed
+
+
+def interpolate_rel_pos_embed(checkpoint_model, image_size, global_attn_indexes):
+    for i in global_attn_indexes:
+        if f'blocks.{i}.attn.rel_pos_h' in checkpoint_model:
+            # relative position tables from the checkpoint
+            rel_pos_h_checkpoint = checkpoint_model[f'blocks.{i}.attn.rel_pos_h']
+            rel_pos_w_checkpoint = checkpoint_model[f'blocks.{i}.attn.rel_pos_w']
+            embedding_size = rel_pos_h_checkpoint.shape[-1]
+
+            orig_size = (rel_pos_h_checkpoint.shape[-2], rel_pos_w_checkpoint.shape[-2])
+
+            new_size = (2*image_size[0]-1, 2*image_size[1]-1)
+
+            # target table lengths for the new feature-map size
+            # (a table of relative offsets has 2*S-1 entries for spatial size S)
+            if orig_size != new_size:
+                print("- Relative Position interpolate from %dx%d to %dx%d" % (orig_size[0], orig_size[1], new_size[0], new_size[1]))
+                # only the table length is interpolated; the embedding dim is unchanged
+                rel_pos_h_checkpoint = rel_pos_h_checkpoint.reshape(orig_size[0], embedding_size).unsqueeze(0).unsqueeze(0)
+                rel_pos_w_checkpoint = rel_pos_w_checkpoint.reshape(orig_size[1], embedding_size).unsqueeze(0).unsqueeze(0)
+                rel_pos_h_checkpoint = torch.nn.functional.interpolate(rel_pos_h_checkpoint,
+                                                                       size=(new_size[0], embedding_size),
+                                                                       mode='bicubic',
+                                                                       align_corners=False)
+                rel_pos_w_checkpoint = torch.nn.functional.interpolate(rel_pos_w_checkpoint,
+                                                                       size=(new_size[1], embedding_size),
+                                                                       mode='bicubic',
+                                                                       align_corners=False)
+                rel_pos_h_checkpoint = rel_pos_h_checkpoint.squeeze(0).squeeze(0).reshape(new_size[0], embedding_size)
+                rel_pos_w_checkpoint = rel_pos_w_checkpoint.squeeze(0).squeeze(0).reshape(new_size[1], embedding_size)
+                checkpoint_model[f'blocks.{i}.attn.rel_pos_h'] = rel_pos_h_checkpoint
+                checkpoint_model[f'blocks.{i}.attn.rel_pos_w'] = rel_pos_w_checkpoint


 # ------------------------ Model Functions ------------------------
 def build_vit_sam(model_name="vit_h", img_size=1024, patch_size=16, img_dim=3, checkpoint=None):
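Both helpers above boil down to bicubic resizes: interpolate_pos_embed reshapes the flattened absolute positional embedding back into its square pretraining grid (64x64 for SAM's 1024-pixel input with 16-pixel patches) and resizes it to the new rectangular grid, while interpolate_rel_pos_embed stretches each global-attention block's relative-position tables to 2*S-1 rows for the new spatial size S. A minimal standalone sketch of the two resizes, with an assumed 16x12 target grid and a toy embedding width of 8:

import torch
import torch.nn.functional as F

embed_dim = 8

# Absolute positional embedding: pretrained on a 64x64 patch grid,
# resized here to an assumed 16x12 target grid.
pos_embed = torch.randn(1, 64 * 64, embed_dim)
grid = pos_embed.reshape(1, 64, 64, embed_dim).permute(0, 3, 1, 2)        # (1, C, 64, 64)
grid = F.interpolate(grid, size=(16, 12), mode='bicubic', align_corners=False)
pos_embed_new = grid.permute(0, 2, 3, 1)                                  # (1, 16, 12, C)
assert pos_embed_new.shape == (1, 16, 12, embed_dim)

# Relative position table: 2*S - 1 rows cover every offset in [-(S-1), S-1].
rel_pos_h = torch.randn(2 * 64 - 1, embed_dim)                            # pretrained, S = 64
new_h = 16                                                                # assumed new height
rel = rel_pos_h[None, None]                                               # (1, 1, 127, C)
rel = F.interpolate(rel, size=(2 * new_h - 1, embed_dim), mode='bicubic', align_corners=False)
rel_pos_h_new = rel[0, 0]
assert rel_pos_h_new.shape == (2 * new_h - 1, embed_dim)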
@@ -454,24 +495,24 @@ if __name__ == '__main__':
     from thop import profile

     # Prepare an image as the input
-    bs, c, h, w = 2, 3, 1024, 1024
+    bs, c, h, w = 2, 3, 256, 256
     x = torch.randn(bs, c, h, w)
     patch_size = 16
     device = torch.device('cuda:1' if torch.cuda.is_available() else 'cpu')

     # Build model
-    model = build_vit_sam(model_name='vit_b', checkpoint="/home/fhw/code/ViTPose/checkpoints/sam/sam_vit_b_01ec64.pth")
+    model = build_vit_sam(model_name='vit_b', checkpoint="/root/autodl-tmp/code/ViTPose/checkpoints/sam/sam_vit_b_01ec64.pth", img_size=(256, 256))
     if torch.cuda.is_available():
         x = x.to(device)
         model = model.to(device)

-    # Inference
-    outputs = model(x)
-    print(outputs.shape)
+    # # Inference
+    # outputs = model(x)
+    # print(outputs.shape)

-    # Compute FLOPs & Params
-    print('==============================')
-    model.eval()
-    flops, params = profile(model, inputs=(x, ), verbose=False)
-    print('GFLOPs : {:.2f}'.format(flops / 1e9 * 2))
-    print('Params : {:.2f} M'.format(params / 1e6))
+    # # Compute FLOPs & Params
+    # print('==============================')
+    # model.eval()
+    # flops, params = profile(model, inputs=(x, ), verbose=False)
+    # print('GFLOPs : {:.2f}'.format(flops / 1e9 * 2))
+    # print('Params : {:.2f} M'.format(params / 1e6))