
Added an FFN after sam_encoder
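In outline: the (frozen) SAM image-encoder tokens now pass through one extra residual feed-forward block plus LayerNorm before being fused with the ViT tokens. A minimal sketch of that pattern, with illustrative sizes (embed_dim, drop_rate and the token shapes below are assumptions, not values from the training config; the names sam_ffn / sam_norm mirror the diff):

import torch
import torch.nn as nn

embed_dim, drop_rate = 768, 0.1      # illustrative, not the repo's config values

# residual FFN + LayerNorm on the SAM encoder output, as introduced by this commit
sam_ffn = nn.Sequential(
    nn.Linear(embed_dim, embed_dim * 4),
    nn.GELU(),
    nn.Linear(embed_dim * 4, embed_dim),
    nn.Dropout(drop_rate),
)
sam_norm = nn.LayerNorm(embed_dim)

x2 = torch.randn(2, 196, embed_dim)  # SAM encoder tokens, shape (B, N, C)
x2 = sam_norm(x2 + sam_ffn(x2))      # same pattern as the new line added to forward()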

branch: main · fhw, 1 year ago · commit 4e8ffc9d8a
  1. mmpose/models/backbones/vit_sam.py — 28 changed lines
  2. test.py — 17 changed lines

mmpose/models/backbones/vit_sam.py (28 changed lines)

@@ -205,10 +205,6 @@ class Cross_Attention(nn.Module):
         self.num_heads = num_heads
         head_dim = dim // num_heads
         self.scale = qk_scale or head_dim ** -0.5
-        self.self_attn = Attention(
-            dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
-            attn_drop=attn_drop, proj_drop=0.)
         self.linear_q = nn.Linear(dim, dim, bias=qkv_bias)
         self.linear_k = nn.Linear(dim, dim, bias=qkv_bias)
@@ -262,7 +258,8 @@ class CustomAttentionFFN(nn.Module):
         self.ffn = nn.Sequential(
             nn.Linear(dim, dim * 4),
             nn.GELU(),
-            nn.Linear(dim * 4, dim)
+            nn.Linear(dim * 4, dim),
+            nn.Dropout(proj_drop)  # plain dropout; torch.nn has no DropPath
         )
         self.norm1 = nn.LayerNorm(dim)
         self.norm2 = nn.LayerNorm(dim)
@@ -332,11 +329,22 @@ class ViTSam(BaseBackbone):
        for param in self.sam_vit.parameters():
            param.requires_grad = False
        # cross-attention
        # self.cross_attn = Cross_Attention(embed_dim, num_heads=num_heads, qkv_bias=qkv_bias, \
        #     qk_scale=qk_scale, attn_drop=attn_drop_rate, proj_drop=drop_rate)
        self.custom_attn_ffn = CustomAttentionFFN(embed_dim, num_heads=num_heads, qkv_bias=qkv_bias, \
            qk_scale=qk_scale, attn_drop=attn_drop_rate, proj_drop=drop_rate)
        # self-attention over the vit tokens, then cross-attention with the sam tokens, then an FFN on the result
        # self.custom_attn_ffn = CustomAttentionFFN(embed_dim, num_heads=num_heads, qkv_bias=qkv_bias, \
        #     qk_scale=qk_scale, attn_drop=attn_drop_rate, proj_drop=drop_rate)
        # add an FFN layer after sam_encoder
        self.sam_ffn = nn.Sequential(
            nn.Linear(embed_dim, embed_dim * 4),
            nn.GELU(),
            nn.Linear(embed_dim * 4, embed_dim),
            nn.Dropout(drop_rate)  # plain dropout; torch.nn has no DropPath
        )
        self.sam_norm = norm_layer(embed_dim)

    def _freeze_stages(self):
        """Freeze parameters."""
@@ -434,9 +442,11 @@ class ViTSam(BaseBackbone):
        # end_time = time.time()
        # print('SAM-ViT forward time: {:.4f} s'.format(end_time - start_time))
        # x1 = x1 + self.cross_attn(x1, x2, x2)
        x2 = self.sam_norm(x2 + self.sam_ffn(x2))
        x1 = x1 + self.cross_attn(x1, x2, x2)
        x1 = self.custom_attn_ffn(x1, x2)
        # x1 = self.custom_attn_ffn(x1, x2)
        xp = x1.permute(0, 2, 1).reshape(B, -1, Hp, Wp).contiguous()  # B, C, Hp, Wp
        return xp
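For context, the comments in the hunk above describe the intended fusion: self-attention over the ViT tokens, cross-attention against the (FFN-processed) SAM tokens, then a feed-forward block. A rough sketch of that pattern, using nn.MultiheadAttention as a stand-in for the repo's Attention / Cross_Attention / CustomAttentionFFN modules (all names, sizes and dropout rates here are assumptions for illustration only):

import torch
import torch.nn as nn

class FusionBlockSketch(nn.Module):
    """Self-attn on ViT tokens -> cross-attn with SAM tokens -> FFN, each with a residual + LayerNorm."""

    def __init__(self, dim=768, num_heads=12, drop=0.1):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(dim, num_heads, dropout=drop, batch_first=True)
        self.cross_attn = nn.MultiheadAttention(dim, num_heads, dropout=drop, batch_first=True)
        self.ffn = nn.Sequential(
            nn.Linear(dim, dim * 4), nn.GELU(), nn.Linear(dim * 4, dim), nn.Dropout(drop))
        self.norm1, self.norm2, self.norm3 = (nn.LayerNorm(dim) for _ in range(3))

    def forward(self, x1, x2):
        # x1: ViT tokens (B, N, C); x2: SAM encoder tokens (B, M, C)
        x1 = self.norm1(x1 + self.self_attn(x1, x1, x1, need_weights=False)[0])
        x1 = self.norm2(x1 + self.cross_attn(x1, x2, x2, need_weights=False)[0])
        x1 = self.norm3(x1 + self.ffn(x1))
        return x1

x1 = torch.randn(2, 192, 768)        # ViT tokens
x2 = torch.randn(2, 196, 768)        # SAM tokens (after the new sam_ffn / sam_norm step)
out = FusionBlockSketch()(x1, x2)    # (2, 192, 768)

The real backbone performs the same kind of fusion on x1 / x2 before permuting the tokens back into a (B, C, Hp, Wp) feature map, as in the last lines of the hunk above.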

test.py (17 changed lines; the file is deleted — it was a sanity check that the frozen SAM pos_embed in the trained model still matched the original SAM checkpoint)

@@ -1,17 +0,0 @@
-import torch
-import numpy as np
-model_1 = torch.load('/home/fhw/code/ViTPose/checkpoints/sam/sam_vit_b_01ec64.pth')
-model_2 = torch.load('/home/fhw/code/ViTPose/work_dirs/ViTSam_base_coco_256x192/best_AP_epoch_1.pth')
-param_1 = model_1['image_encoder.pos_embed'].numpy()
-param_2 = model_2['state_dict']['backbone.sam_vit.pos_embed'].numpy()
-# for name, param in model_2.items():
-#     print(name)
-# print(model_2['state_dict']['backbone.sam_vit.pos_embed'])
-is_equal = np.array_equal(param_1, param_2)
-print(is_equal)