diff --git a/mmpose/models/backbones/vit_sam.py b/mmpose/models/backbones/vit_sam.py index 80c1b4f..7858034 100644 --- a/mmpose/models/backbones/vit_sam.py +++ b/mmpose/models/backbones/vit_sam.py @@ -205,10 +205,6 @@ class Cross_Attention(nn.Module): self.num_heads = num_heads head_dim = dim // num_heads self.scale = qk_scale or head_dim ** -0.5 - - self.self_attn = Attention( - dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, - attn_drop=attn_drop, proj_drop=0.) self.linear_q = nn.Linear(dim, dim, bias=qkv_bias) self.linear_k = nn.Linear(dim, dim, bias=qkv_bias) @@ -262,7 +258,8 @@ class CustomAttentionFFN(nn.Module): self.ffn = nn.Sequential( nn.Linear(dim, dim * 4), nn.GELU(), - nn.Linear(dim * 4, dim) + nn.Linear(dim * 4, dim), + nn.DropPath(proj_drop) ) self.norm1 = nn.LayerNorm(dim) self.norm2 = nn.LayerNorm(dim) @@ -332,11 +329,22 @@ class ViTSam(BaseBackbone): for param in self.sam_vit.parameters(): param.requires_grad = False + # 交叉注意力 # self.cross_attn = Cross_Attention(embed_dim, num_heads=num_heads, qkv_bias=qkv_bias, \ # qk_scale=qk_scale, attn_drop=attn_drop_rate, proj_drop=drop_rate) - self.custom_attn_ffn = CustomAttentionFFN(embed_dim, num_heads=num_heads, qkv_bias=qkv_bias, \ - qk_scale=qk_scale, attn_drop=attn_drop_rate, proj_drop=drop_rate) + # vit_token做自注意力后,再和sam_token做交叉注意力,得到的结果再经过FFN + # self.custom_attn_ffn = CustomAttentionFFN(embed_dim, num_heads=num_heads, qkv_bias=qkv_bias, \ + # qk_scale=qk_scale, attn_drop=attn_drop_rate, proj_drop=drop_rate) + + # 在sam_encoder后面加一层ffn + self.sam_ffn = nn.Sequential( + nn.Linear(embed_dim, embed_dim * 4), + nn.GELU(), + nn.Linear(embed_dim * 4, embed_dim), + nn.DropPath(drop_rate) + ) + self.sam_norm = norm_layer(embed_dim) def _freeze_stages(self): """Freeze parameters.""" @@ -434,9 +442,11 @@ class ViTSam(BaseBackbone): # end_time = time.time() # print('SAM-ViT forward time: {:.4f}秒'.format(end_time - start_time)) - # x1 = x1 + self.cross_attn(x1, x2, x2) + x2 = self.sam_norm(x2 + self.sam_ffn(x2)) + + x1 = x1 + self.cross_attn(x1, x2, x2) - x1 = self.custom_attn_ffn(x1, x2) + # x1 = self.custom_attn_ffn(x1, x2) xp = x1.permute(0, 2, 1).reshape(B, -1, Hp, Wp).contiguous() # B, C, Hp, Wp return xp