@@ -206,10 +206,6 @@ class Cross_Attention(nn.Module):
         head_dim = dim // num_heads
         self.scale = qk_scale or head_dim ** -0.5
 
-        self.self_attn = Attention(
-            dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
-            attn_drop=attn_drop, proj_drop=0.)
-
         self.linear_q = nn.Linear(dim, dim, bias=qkv_bias)
         self.linear_k = nn.Linear(dim, dim, bias=qkv_bias)
         self.linear_v = nn.Linear(dim, dim, bias=qkv_bias)
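The projections above imply a standard multi-head cross-attention. The actual `Cross_Attention.forward` lies outside this hunk, so the following is only a sketch of what these layers compute; the `(B, N, C)` token shapes and a `num_heads` attribute stored in `__init__` are assumptions:

```python
import torch

def cross_attention_forward(m, q, k, v):
    """Sketch of a forward pass using m.linear_q/k/v and m.scale.

    q: (B, Nq, C) queries (e.g. ViT tokens); k, v: (B, Nk, C) keys/values
    (e.g. SAM tokens). Assumes m stored num_heads in __init__.
    """
    B, Nq, C = q.shape
    Nk, h = k.shape[1], m.num_heads
    # project, then split channels into heads: (B, N, C) -> (B, h, N, C // h)
    q = m.linear_q(q).reshape(B, Nq, h, C // h).permute(0, 2, 1, 3)
    k = m.linear_k(k).reshape(B, Nk, h, C // h).permute(0, 2, 1, 3)
    v = m.linear_v(v).reshape(B, Nk, h, C // h).permute(0, 2, 1, 3)
    attn = ((q @ k.transpose(-2, -1)) * m.scale).softmax(dim=-1)  # (B, h, Nq, Nk)
    # merge heads back: (B, h, Nq, C // h) -> (B, Nq, C)
    return (attn @ v).transpose(1, 2).reshape(B, Nq, C)
```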
@@ -262,7 +258,8 @@ class CustomAttentionFFN(nn.Module):
         self.ffn = nn.Sequential(
             nn.Linear(dim, dim * 4),
             nn.GELU(),
-            nn.Linear(dim * 4, dim)
+            nn.Linear(dim * 4, dim),
+            nn.DropPath(proj_drop)  # NB: torch.nn has no DropPath; see note below
         )
         self.norm1 = nn.LayerNorm(dim)
         self.norm2 = nn.LayerNorm(dim)
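`nn.DropPath(proj_drop)` will raise an `AttributeError`: PyTorch's `torch.nn` has no `DropPath` module. It normally comes from `timm.models.layers` (or mmcv's bricks). A minimal sketch of the stochastic-depth behavior it provides, in case the extra dependency is unwanted:

```python
import torch.nn as nn

class DropPath(nn.Module):
    """Stochastic depth: zero out whole samples and rescale the survivors."""

    def __init__(self, drop_prob: float = 0.):
        super().__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        if self.drop_prob == 0. or not self.training:
            return x
        keep_prob = 1 - self.drop_prob
        # one Bernoulli draw per sample, broadcast over all remaining dims
        shape = (x.shape[0],) + (1,) * (x.ndim - 1)
        mask = x.new_empty(shape).bernoulli_(keep_prob)
        return x * mask / keep_prob
```

Placed at the end of the `ffn` Sequential, this drops the whole FFN branch per sample whenever the output is used residually (as `x + self.ffn(x)`); if plain element-wise dropout was intended instead, `nn.Dropout(proj_drop)` is the usual choice.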
@@ -332,11 +329,22 @@ class ViTSam(BaseBackbone):
         for param in self.sam_vit.parameters():
             param.requires_grad = False
 
         # Cross-attention
-        # self.cross_attn = Cross_Attention(embed_dim, num_heads=num_heads, qkv_bias=qkv_bias, \
-        #                                   qk_scale=qk_scale, attn_drop=attn_drop_rate, proj_drop=drop_rate)
 
+        self.custom_attn_ffn = CustomAttentionFFN(embed_dim, num_heads=num_heads, qkv_bias=qkv_bias, \
+                                                  qk_scale=qk_scale, attn_drop=attn_drop_rate, proj_drop=drop_rate)
+        # Self-attention on the vit_token, then cross-attention with the sam_token, and the result goes through an FFN
+        # self.custom_attn_ffn = CustomAttentionFFN(embed_dim, num_heads=num_heads, qkv_bias=qkv_bias, \
+        #                                           qk_scale=qk_scale, attn_drop=attn_drop_rate, proj_drop=drop_rate)
 
+        # Add an FFN layer after the sam_encoder
+        self.sam_ffn = nn.Sequential(
+            nn.Linear(embed_dim, embed_dim * 4),
+            nn.GELU(),
+            nn.Linear(embed_dim * 4, embed_dim),
+            nn.DropPath(drop_rate)  # NB: same caveat, DropPath is not in torch.nn
+        )
+        self.sam_norm = norm_layer(embed_dim)
 
     def _freeze_stages(self):
         """Freeze parameters."""
@@ -434,9 +442,11 @@ class ViTSam(BaseBackbone):
         # end_time = time.time()
         # print('SAM-ViT forward time: {:.4f}s'.format(end_time - start_time))
 
-        # x1 = x1 + self.cross_attn(x1, x2, x2)
+        x2 = self.sam_norm(x2 + self.sam_ffn(x2))
+
+        x1 = x1 + self.cross_attn(x1, x2, x2)
 
-        x1 = self.custom_attn_ffn(x1, x2)
+        # x1 = self.custom_attn_ffn(x1, x2)
 
         xp = x1.permute(0, 2, 1).reshape(B, -1, Hp, Wp).contiguous()  # B, C, Hp, Wp
         return xp
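A toy shape walk of the new tail of `forward`, assuming `x1`/`x2` are `(B, Hp*Wp, embed_dim)` token sequences from the ViT and SAM branches, and that `self.cross_attn` is constructed elsewhere in `__init__` (it is called here but not built in any hunk shown); `model` and the concrete sizes are hypothetical:

```python
import torch

B, Hp, Wp, C = 2, 16, 16, 768
x1 = torch.randn(B, Hp * Wp, C)  # ViT tokens
x2 = torch.randn(B, Hp * Wp, C)  # SAM tokens (frozen encoder output)

x2 = model.sam_norm(x2 + model.sam_ffn(x2))  # residual FFN + post-norm on SAM tokens
x1 = x1 + model.cross_attn(x1, x2, x2)       # ViT queries attend to SAM keys/values
xp = x1.permute(0, 2, 1).reshape(B, -1, Hp, Wp).contiguous()
assert xp.shape == (B, C, Hp, Wp)
```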