From 4e8ffc9d8a0aa10ddcb885d6adca64bf397dbe38 Mon Sep 17 00:00:00 2001
From: fhw
Date: Wed, 10 Jul 2024 18:53:15 +0800
Subject: [PATCH] Add FFN after sam_encoder

Apply a residual FFN with LayerNorm to the frozen SAM encoder output
before fusing it with the ViT tokens through cross-attention. Re-enable
the Cross_Attention fusion path, disable the CustomAttentionFFN path,
and delete the ad-hoc pos_embed comparison script test.py.
---
 mmpose/models/backbones/vit_sam.py | 33 ++++++++++++++++++++-----------
 test.py                            | 17 -----------------
 2 files changed, 22 insertions(+), 28 deletions(-)
 delete mode 100644 test.py

diff --git a/mmpose/models/backbones/vit_sam.py b/mmpose/models/backbones/vit_sam.py
index 80c1b4f..7858034 100644
--- a/mmpose/models/backbones/vit_sam.py
+++ b/mmpose/models/backbones/vit_sam.py
@@ -205,10 +205,6 @@ class Cross_Attention(nn.Module):
         self.num_heads = num_heads
         head_dim = dim // num_heads
         self.scale = qk_scale or head_dim ** -0.5
-
-        self.self_attn = Attention(
-            dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
-            attn_drop=attn_drop, proj_drop=0.)
 
         self.linear_q = nn.Linear(dim, dim, bias=qkv_bias)
         self.linear_k = nn.Linear(dim, dim, bias=qkv_bias)
@@ -262,7 +258,8 @@ class CustomAttentionFFN(nn.Module):
         self.ffn = nn.Sequential(
             nn.Linear(dim, dim * 4),
             nn.GELU(),
-            nn.Linear(dim * 4, dim)
+            nn.Linear(dim * 4, dim),
+            nn.Dropout(proj_drop)
         )
         self.norm1 = nn.LayerNorm(dim)
         self.norm2 = nn.LayerNorm(dim)
@@ -332,10 +329,21 @@ class ViTSam(BaseBackbone):
         for param in self.sam_vit.parameters():
             param.requires_grad = False
 
-        # self.cross_attn = Cross_Attention(embed_dim, num_heads=num_heads, qkv_bias=qkv_bias, \
-        #     qk_scale=qk_scale, attn_drop=attn_drop_rate, proj_drop=drop_rate)
-        self.custom_attn_ffn = CustomAttentionFFN(embed_dim, num_heads=num_heads, qkv_bias=qkv_bias, \
-            qk_scale=qk_scale, attn_drop=attn_drop_rate, proj_drop=drop_rate)
+        # Cross-attention: ViT tokens attend to the SAM tokens
+        self.cross_attn = Cross_Attention(embed_dim, num_heads=num_heads, qkv_bias=qkv_bias, \
+            qk_scale=qk_scale, attn_drop=attn_drop_rate, proj_drop=drop_rate)
+        # Self-attention on vit_token, then cross-attention with sam_token, then FFN (disabled)
+        # self.custom_attn_ffn = CustomAttentionFFN(embed_dim, num_heads=num_heads, qkv_bias=qkv_bias, \
+        #     qk_scale=qk_scale, attn_drop=attn_drop_rate, proj_drop=drop_rate)
+
+        # Add an FFN layer after the sam_encoder
+        self.sam_ffn = nn.Sequential(
+            nn.Linear(embed_dim, embed_dim * 4),
+            nn.GELU(),
+            nn.Linear(embed_dim * 4, embed_dim),
+            nn.Dropout(drop_rate)
+        )
+        self.sam_norm = norm_layer(embed_dim)
 
     def _freeze_stages(self):
         """Freeze parameters."""
@@ -434,9 +442,12 @@ class ViTSam(BaseBackbone):
         # end_time = time.time()
         # print('SAM-ViT forward time: {:.4f}s'.format(end_time - start_time))
 
-        # x1 = x1 + self.cross_attn(x1, x2, x2)
+        # Residual FFN + LayerNorm on the SAM tokens
+        x2 = self.sam_norm(x2 + self.sam_ffn(x2))
+
+        x1 = x1 + self.cross_attn(x1, x2, x2)
 
-        x1 = self.custom_attn_ffn(x1, x2)
+        # x1 = self.custom_attn_ffn(x1, x2)
 
         xp = x1.permute(0, 2, 1).reshape(B, -1, Hp, Wp).contiguous()  # B, C, Hp, Wp
         return xp
diff --git a/test.py b/test.py
deleted file mode 100644
index fe0a338..0000000
--- a/test.py
+++ /dev/null
@@ -1,17 +0,0 @@
-import torch
-import numpy as np
-
-model_1 = torch.load('/home/fhw/code/ViTPose/checkpoints/sam/sam_vit_b_01ec64.pth')
-model_2 = torch.load('/home/fhw/code/ViTPose/work_dirs/ViTSam_base_coco_256x192/best_AP_epoch_1.pth')
-
-param_1 = model_1['image_encoder.pos_embed'].numpy()
-param_2 = model_2['state_dict']['backbone.sam_vit.pos_embed'].numpy()
-
-# for name, param in model_2.items():
-#     print(name)
-
-# print(model_2['state_dict']['backbone.sam_vit.pos_embed'])
-
-is_equal = np.array_equal(param_1, param_2)
-
-print(is_equal)
\ No newline at end of file
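
For reference, below is a minimal standalone sketch of the token flow this
patch sets up. It is illustrative only: torch.nn.MultiheadAttention stands in
for the repo's Cross_Attention module, and the shapes, dropout rate, and
variable names are assumptions rather than values taken from vit_sam.py.

import torch
import torch.nn as nn

embed_dim, num_heads = 768, 12          # ViT-Base sizes (assumed, matching sam_vit_b)

# The new sam_ffn/sam_norm pair added after the frozen SAM encoder
sam_ffn = nn.Sequential(
    nn.Linear(embed_dim, embed_dim * 4),
    nn.GELU(),
    nn.Linear(embed_dim * 4, embed_dim),
    nn.Dropout(0.1),                    # stand-in for drop_rate
)
sam_norm = nn.LayerNorm(embed_dim)
# Stand-in for Cross_Attention, not the repo's implementation
cross_attn = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)

x1 = torch.randn(2, 192, embed_dim)     # trainable ViT tokens, B x N x C
x2 = torch.randn(2, 192, embed_dim)     # frozen SAM encoder tokens

# Residual FFN + LayerNorm on the SAM tokens, then the ViT tokens attend
# to the refined SAM tokens as keys/values.
x2 = sam_norm(x2 + sam_ffn(x2))
x1 = x1 + cross_attn(x1, x2, x2)[0]
print(x1.shape)                         # torch.Size([2, 192, 768])

Note that the added block is post-norm (LayerNorm applied after the residual
add), unlike the pre-norm layout used inside standard ViT blocks.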