You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
518 lines
22 KiB
518 lines
22 KiB
# --------------------------------------------------------------------
|
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
# All rights reserved.
|
|
|
|
# This source code is licensed under the license found in the
|
|
# LICENSE file in the root directory of this source tree.
|
|
# --------------------------------------------------------------------
|
|
|
|
from typing import Optional, Tuple, Type
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
import torch.nn
|
|
import torch.nn.functional as F
|
|
|
|
from functools import partial
|
|
|
|
|
|
# ---------------------- Vision Transformer of Segment-Anything ----------------------
|
|
class ImageEncoderViT(nn.Module):
    """Vision Transformer image encoder adapted from Segment-Anything.

    The neck used in the Segment-Anything encoder is removed, so ``forward``
    returns the token sequence produced by the last transformer block.

    Args:
        img_size: Input resolution, either a single int (square input) or an
            ``(height, width)`` pair.
        patch_size: Side length of the square patches.
        in_chans: Number of input image channels.
        embed_dim: Token embedding dimension.
        depth: Number of transformer blocks.
        num_heads: Number of attention heads in each block.
        mlp_ratio: Ratio of the MLP hidden width to ``embed_dim``.
        qkv_bias: Whether the qkv projection carries a bias.
        norm_layer: Constructor of the normalization layer.
        act_layer: Constructor of the activation layer.
        use_abs_pos: Whether a learnable absolute position embedding is added.
        use_rel_pos: Whether relative position terms are added in attention.
        window_size: Window size handed to the non-global blocks.
        global_attn_indexes: Indexes of the blocks built with window size 0.
        checkpoint: Path of a SAM checkpoint, or ``None`` to skip loading.
    """

    def __init__(self,
                 img_size=1024,  # int or (height, width)
                 patch_size: int = 16,
                 in_chans: int = 3,
                 embed_dim: int = 768,
                 depth: int = 12,
                 num_heads: int = 12,
                 mlp_ratio: float = 4.0,
                 qkv_bias: bool = True,
                 norm_layer: Type[nn.Module] = nn.LayerNorm,
                 act_layer: Type[nn.Module] = nn.GELU,
                 use_abs_pos: bool = True,
                 use_rel_pos: bool = True,
                 window_size: int = 0,
                 global_attn_indexes: Tuple[int, ...] = (),
                 checkpoint=None
                 ) -> None:
        super().__init__()
        # The rest of the class indexes img_size as (H, W); a bare int (such as
        # the default 1024) is therefore expanded to a square size first.
        if isinstance(img_size, int):
            img_size = (img_size, img_size)
        else:
            img_size = tuple(img_size)
        grid_size = (img_size[0] // patch_size, img_size[1] // patch_size)

        self.img_size = img_size
        self.patch_size = patch_size
        self.embed_dim = embed_dim
        self.num_patches = grid_size[0] * grid_size[1]
        self.pos_embed: Optional[nn.Parameter] = None
        self.checkpoint = checkpoint
        if use_abs_pos:
            # Absolute positional embedding, one vector per patch position.
            self.pos_embed = nn.Parameter(
                torch.zeros(1, grid_size[0], grid_size[1], embed_dim)
            )

        # ------------ Model parameters ------------
        ## Patch embedding layer
        self.patch_embed = PatchEmbed(
            kernel_size=(patch_size, patch_size),
            stride=(patch_size, patch_size),
            in_chans=in_chans,
            embed_dim=embed_dim,
        )

        ## ViT blocks
        self.blocks = nn.ModuleList()
        for i in range(depth):
            block = Block(dim=embed_dim,
                          num_heads=num_heads,
                          mlp_ratio=mlp_ratio,
                          qkv_bias=qkv_bias,
                          norm_layer=norm_layer,
                          act_layer=act_layer,
                          use_rel_pos=use_rel_pos,
                          # Blocks listed in global_attn_indexes get window size 0.
                          window_size=window_size if i not in global_attn_indexes else 0,
                          input_size=grid_size,
                          )
            self.blocks.append(block)

        self.load_pretrained(global_attn_indexes)

    def load_pretrained(self, global_attn_indexes=()):
        """Load the ``image_encoder.*`` weights of a SAM checkpoint, if one was given.

        Tensors whose shape matches the model are copied as-is. The absolute
        position embedding and the relative position tables of the blocks in
        ``global_attn_indexes`` are accepted even when their shape differs and
        are interpolated to the current patch grid before loading. Any other
        mismatching tensor is reported and skipped.
        """
        if self.checkpoint is None:
            print('No SAM pretrained.')
            return

        print('Loading SAM pretrained weight from : {}'.format(self.checkpoint))
        # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
        checkpoint_state_dict = torch.load(self.checkpoint, map_location="cpu")
        model_state_dict = self.state_dict()
        encoder_state_dict = {}

        # Relative position tables that may be resized instead of rejected.
        resizable_rel_pos = set()
        for i in global_attn_indexes:
            resizable_rel_pos.add("blocks." + str(i) + ".attn.rel_pos_h")
            resizable_rel_pos.add("blocks." + str(i) + ".attn.rel_pos_w")

        prefix = "image_encoder."
        for k, tensor in checkpoint_state_dict.items():
            if not k.startswith(prefix):
                continue
            name = k[len(prefix):]
            if name not in model_state_dict:
                continue
            shape_model = tuple(model_state_dict[name].shape)
            shape_checkpoint = tuple(tensor.shape)
            if shape_model == shape_checkpoint or "pos_embed" in k:
                encoder_state_dict[name] = tensor
            elif name in resizable_rel_pos:
                encoder_state_dict[name] = tensor
            else:
                print("Shape unmatch: ", k)

        grid_size = (self.img_size[0] // self.patch_size, self.img_size[1] // self.patch_size)
        # interpolate position embedding
        interpolate_pos_embed(self, encoder_state_dict, grid_size)
        # interpolate relative position embedding
        interpolate_rel_pos_embed(encoder_state_dict, grid_size, global_attn_indexes)

        # load the weight
        self.load_state_dict(encoder_state_dict, strict=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Embed the image into patches, run all blocks and flatten the grid."""
        x = self.patch_embed(x)
        if self.pos_embed is not None:
            x = x + self.pos_embed

        for blk in self.blocks:
            x = blk(x)

        # [B, H, W, C] -> [B, N, C]
        return x.flatten(1, 2)
|
|
|
|
|
|
# ---------------------- Model modules ----------------------
|
|
class MLPBlock(nn.Module):
    """Feed-forward block: linear -> activation -> linear.

    Args:
        embedding_dim: Size of the input and output features.
        mlp_dim: Size of the hidden layer.
        act: Constructor of the activation module placed between the layers.
    """

    def __init__(self,
                 embedding_dim: int,
                 mlp_dim: int,
                 act: Type[nn.Module] = nn.GELU,
                 ) -> None:
        super().__init__()
        self.lin1 = nn.Linear(embedding_dim, mlp_dim)
        self.lin2 = nn.Linear(mlp_dim, embedding_dim)
        self.act = act()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Expand to ``mlp_dim``, apply the activation, project back."""
        hidden = self.lin1(x)
        hidden = self.act(hidden)
        return self.lin2(hidden)
|
|
|
|
class LayerNorm2d(nn.Module):
    """Layer normalization over dimension 1 with a per-channel affine transform.

    Args:
        num_channels: Length of the learnable ``weight`` and ``bias`` vectors.
        eps: Constant added to the variance before the square root.
    """

    def __init__(self, num_channels: int, eps: float = 1e-6) -> None:
        super().__init__()
        self.weight = nn.Parameter(torch.ones(num_channels))
        self.bias = nn.Parameter(torch.zeros(num_channels))
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Normalize ``x`` along dim 1, then scale and shift channel-wise."""
        mean = x.mean(1, keepdim=True)
        centered = x - mean
        variance = centered.pow(2).mean(1, keepdim=True)
        normalized = centered / torch.sqrt(variance + self.eps)
        # weight/bias are broadcast over the two trailing dimensions.
        return self.weight[:, None, None] * normalized + self.bias[:, None, None]
|
|
|
|
class Block(nn.Module):
    """Transformer block: pre-norm attention and pre-norm MLP, both residual.

    Args:
        dim: Channel dimension of the tokens.
        num_heads: Number of attention heads.
        mlp_ratio: Ratio of the MLP hidden width to ``dim``.
        qkv_bias: Whether the qkv projection carries a bias.
        norm_layer: Constructor of the normalization layers.
        act_layer: Constructor of the MLP activation.
        use_rel_pos: Whether attention adds relative position terms.
        window_size: Attention window size; when 0 the input is not partitioned.
        input_size: Token grid size passed to the attention layer when
            ``window_size`` is 0.
    """

    def __init__(self,
                 dim: int,
                 num_heads: int,
                 mlp_ratio: float = 4.0,
                 qkv_bias: bool = True,
                 norm_layer: Type[nn.Module] = nn.LayerNorm,
                 act_layer: Type[nn.Module] = nn.GELU,
                 use_rel_pos: bool = False,
                 window_size: int = 0,
                 input_size: Optional[Tuple[int, int]] = None,
                 ) -> None:
        super().__init__()
        # -------------- Basic parameters --------------
        self.window_size = window_size

        # -------------- Model parameters --------------
        # Windowed blocks attend inside (window_size x window_size) windows,
        # so that is the size the attention layer is built for.
        if window_size == 0:
            attn_input_size = input_size
        else:
            attn_input_size = (window_size, window_size)

        self.norm1 = norm_layer(dim)
        self.attn = Attention(dim=dim,
                              num_heads=num_heads,
                              qkv_bias=qkv_bias,
                              use_rel_pos=use_rel_pos,
                              input_size=attn_input_size,
                              )
        self.norm2 = norm_layer(dim)
        self.mlp = MLPBlock(embedding_dim=dim, mlp_dim=int(dim * mlp_ratio), act=act_layer)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply attention (windowed when ``window_size > 0``) and the MLP."""
        residual = x
        x = self.norm1(x)

        windowed = self.window_size > 0
        if windowed:
            # Split into windows, remembering the sizes needed to undo it.
            height, width = x.shape[1], x.shape[2]
            x, padded_hw = window_partition(x, self.window_size)

        x = self.attn(x)

        if windowed:
            x = window_unpartition(x, self.window_size, padded_hw, (height, width))

        x = residual + x
        return x + self.mlp(self.norm2(x))
|
|
|
|
class Attention(nn.Module):
    """Multi-head self-attention over a [B, H, W, C] token grid.

    Args:
        dim: Channel dimension of the tokens.
        num_heads: Number of attention heads.
        qkv_bias: Whether the qkv projection carries a bias.
        use_rel_pos: Whether relative position terms are added to the logits.
        input_size: Token grid size used to size the relative position
            tables; must be given when ``use_rel_pos`` is True.
    """

    def __init__(self,
                 dim: int,
                 num_heads: int = 8,
                 qkv_bias: bool = True,
                 use_rel_pos: bool = False,
                 input_size: Optional[Tuple[int, int]] = None,
                 ) -> None:
        super().__init__()
        # -------------- Basic parameters --------------
        head_dim = dim // num_heads
        self.num_heads = num_heads
        self.scale = head_dim**-0.5
        self.use_rel_pos = use_rel_pos
        if self.use_rel_pos:
            assert (
                input_size is not None
            ), "Input size must be provided if using relative positional encoding."
            # One row per possible relative offset along each axis.
            self.rel_pos_h = nn.Parameter(torch.zeros(2 * input_size[0] - 1, head_dim))
            self.rel_pos_w = nn.Parameter(torch.zeros(2 * input_size[1] - 1, head_dim))

        # -------------- Model parameters --------------
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.proj = nn.Linear(dim, dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the attended tokens with the same [B, H, W, C] layout."""
        B, H, W, _ = x.shape
        num_tokens = H * W

        # (B, N, 3, nHead, C) -> (3, B, nHead, N, C) -> (3, B * nHead, N, C)
        qkv = self.qkv(x).reshape(B, num_tokens, 3, self.num_heads, -1)
        qkv = qkv.permute(2, 0, 3, 1, 4).reshape(3, B * self.num_heads, num_tokens, -1)
        q, k, v = qkv.unbind(0)

        attn = (q * self.scale) @ k.transpose(-2, -1)
        if self.use_rel_pos:
            attn = add_decomposed_rel_pos(attn, q, self.rel_pos_h, self.rel_pos_w, (H, W), (H, W))
        attn = attn.softmax(dim=-1)

        # Merge the heads back into the channel dimension.
        out = (attn @ v).view(B, self.num_heads, H, W, -1)
        out = out.permute(0, 2, 3, 1, 4).reshape(B, H, W, -1)
        return self.proj(out)
|
|
|
|
class PatchEmbed(nn.Module):
    """Project an image into patch embeddings with a single 2D convolution.

    Args:
        kernel_size: Kernel size of the projection convolution.
        stride: Stride of the projection convolution.
        padding: Padding of the projection convolution.
        in_chans: Number of input image channels.
        embed_dim: Number of output (embedding) channels.
    """

    def __init__(self,
                 kernel_size: Tuple[int, int] = (16, 16),
                 stride: Tuple[int, int] = (16, 16),
                 padding: Tuple[int, int] = (0, 0),
                 in_chans: int = 3,
                 embed_dim: int = 768,
                 ) -> None:
        super().__init__()
        self.proj = nn.Conv2d(
            in_chans,
            embed_dim,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return channel-last embeddings: [B, C, H, W] -> [B, H, W, C]."""
        embedded = self.proj(x)
        return embedded.permute(0, 2, 3, 1)
|
|
|
|
|
|
# ---------------------- Model functions ----------------------
|
|
def window_partition(x: torch.Tensor,
                     window_size: int,
                     ) -> Tuple[torch.Tensor, Tuple[int, int]]:
    """
    Split a token grid into non-overlapping square windows, zero-padding the
    bottom/right edges when the grid is not a multiple of the window size.

    Args:
        x (tensor): input tokens with [B, H, W, C].
        window_size (int): window size.

    Returns:
        windows: windows after partition with [B * num_windows, window_size, window_size, C].
        (Hp, Wp): padded height and width before partition
    """
    B, H, W, C = x.shape

    # Amount needed to reach the next multiple of window_size (0 if already one).
    extra_h = -H % window_size
    extra_w = -W % window_size
    if extra_h > 0 or extra_w > 0:
        # F.pad lists the last dimension first: (C, C, W_left, W_right, H_top, H_bottom).
        x = F.pad(x, (0, 0, 0, extra_w, 0, extra_h))
    Hp = H + extra_h
    Wp = W + extra_w

    rows, cols = Hp // window_size, Wp // window_size
    grid = x.view(B, rows, window_size, cols, window_size, C)
    windows = grid.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)

    return windows, (Hp, Wp)
|
|
|
|
def window_unpartition(windows: torch.Tensor,
                       window_size: int,
                       pad_hw: Tuple[int, int],
                       hw: Tuple[int, int],
                       ) -> torch.Tensor:
    """
    Reassemble windows into the original token grid and strip any padding.

    Args:
        windows (tensor): input tokens with [B * num_windows, window_size, window_size, C].
        window_size (int): window size.
        pad_hw (Tuple): padded height and width (Hp, Wp).
        hw (Tuple): original height and width (H, W) before padding.

    Returns:
        x: unpartitioned sequences with [B, H, W, C].
    """
    Hp, Wp = pad_hw
    H, W = hw
    rows, cols = Hp // window_size, Wp // window_size

    windows_per_image = Hp * Wp // window_size // window_size
    B = windows.shape[0] // windows_per_image

    grid = windows.view(B, rows, cols, window_size, window_size, -1)
    x = grid.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1)

    # Crop only when padding was actually added.
    needs_crop = Hp > H or Wp > W
    return x[:, :H, :W, :].contiguous() if needs_crop else x
|
|
|
|
def get_rel_pos(q_size: int,
                k_size: int,
                rel_pos: torch.Tensor,
                ) -> torch.Tensor:
    """
    Look up relative positional embeddings for every (query, key) position
    pair along one axis.

    Args:
        q_size (int): size of query q.
        k_size (int): size of key k.
        rel_pos (Tensor): relative position embeddings (L, C).

    Returns:
        Extracted positional embeddings according to relative positions.
    """
    table_len = int(2 * max(q_size, k_size) - 1)

    table = rel_pos
    if rel_pos.shape[0] != table_len:
        # The stored table has a different length: resize it linearly
        # along the position axis.
        channels_first = rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1)
        resized = F.interpolate(channels_first, size=table_len, mode="linear")
        table = resized.reshape(-1, table_len).permute(1, 0)

    # Scale the coords with short length if shapes for q and k are different.
    q_scale = max(k_size / q_size, 1.0)
    k_scale = max(q_size / k_size, 1.0)
    q_coords = torch.arange(q_size)[:, None] * q_scale
    k_coords = torch.arange(k_size)[None, :] * k_scale
    # Shift so that the smallest offset maps to index 0.
    indexes = (q_coords - k_coords) + (k_size - 1) * k_scale

    return table[indexes.long()]
|
|
|
|
def add_decomposed_rel_pos(attn: torch.Tensor,
                           q: torch.Tensor,
                           rel_pos_h: torch.Tensor,
                           rel_pos_w: torch.Tensor,
                           q_size: Tuple[int, int],
                           k_size: Tuple[int, int],
                           ) -> torch.Tensor:
    """
    Add height-wise and width-wise relative position terms to attention logits.

    Args:
        attn (Tensor): attention logits, viewed here as (B, q_h * q_w, k_h * k_w).
        q (Tensor): queries, reshaped here to (B, q_h, q_w, C).
        rel_pos_h (Tensor): relative position table for the height axis.
        rel_pos_w (Tensor): relative position table for the width axis.
        q_size (Tuple): query grid size (q_h, q_w).
        k_size (Tuple): key grid size (k_h, k_w).

    Returns:
        Attention logits with both relative position terms added, shaped
        (B, q_h * q_w, k_h * k_w).
    """
    q_h, q_w = q_size
    k_h, k_w = k_size
    table_h = get_rel_pos(q_h, k_h, rel_pos_h)
    table_w = get_rel_pos(q_w, k_w, rel_pos_w)

    B, _, dim = q.shape
    q_grid = q.reshape(B, q_h, q_w, dim)
    bias_h = torch.einsum("bhwc,hkc->bhwk", q_grid, table_h)
    bias_w = torch.einsum("bhwc,wkc->bhwk", q_grid, table_w)

    # The height term broadcasts over key columns, the width term over key rows.
    logits = attn.view(B, q_h, q_w, k_h, k_w)
    logits = logits + bias_h.unsqueeze(-1) + bias_w.unsqueeze(-2)
    return logits.view(B, q_h * q_w, k_h * k_w)
|
|
|
|
def interpolate_pos_embed(model, checkpoint_model, new_size):
    """Resize the checkpoint's absolute position embedding to ``new_size`` in place.

    Nothing happens when ``checkpoint_model`` has no ``'pos_embed'`` entry,
    when the model has no position embedding (``model.pos_embed`` is None),
    or when the checkpoint grid already equals ``new_size``.

    Args:
        model: Object exposing ``pos_embed`` (a [1, H, W, C] tensor or None)
            and ``num_patches``.
        checkpoint_model (dict): State dict whose ``'pos_embed'`` entry is
            replaced by the resized embedding.
        new_size: Target grid size as ``(height, width)``.

    NOTE(review): the checkpoint grid is assumed to be square; its side is
    recovered as the square root of the number of position tokens.
    """
    if 'pos_embed' not in checkpoint_model:
        return
    pos_embed_model = model.pos_embed
    if pos_embed_model is None:
        # The model has no absolute position embedding to size against.
        return

    new_size = (int(new_size[0]), int(new_size[1]))
    pos_embed_checkpoint = checkpoint_model['pos_embed']
    embedding_size = pos_embed_checkpoint.shape[-1]

    # [B, H, W, C] -> [B, N, C]
    pos_embed_checkpoint = pos_embed_checkpoint.flatten(1, 2)
    pos_embed_model = pos_embed_model.flatten(1, 2)

    # Tokens beyond the patch grid are carried over without interpolation.
    num_extra_tokens = pos_embed_model.shape[-2] - model.num_patches

    # height (== width) of the checkpoint position embedding
    orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
    if (orig_size, orig_size) == new_size:
        return

    print("- Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size[0], new_size[1]))
    extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
    # only the position tokens are interpolated
    pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
    pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
    pos_tokens = torch.nn.functional.interpolate(pos_tokens,
                                                 size=new_size,
                                                 mode='bicubic',
                                                 align_corners=False)
    pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
    new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
    checkpoint_model['pos_embed'] = new_pos_embed.reshape(-1, new_size[0], new_size[1], embedding_size)
|
|
|
|
def interpolate_rel_pos_embed(checkpoint_model, image_size, global_attn_indexes):
    """Resize the relative position tables of the listed blocks in place.

    For every block index in ``global_attn_indexes`` whose ``rel_pos_h`` entry
    exists in ``checkpoint_model``, both ``rel_pos_h`` and ``rel_pos_w`` are
    bicubically resized to ``2 * size - 1`` rows when their current lengths
    differ from that target.

    Args:
        checkpoint_model (dict): State dict that is updated in place.
        image_size: Token grid size as ``(height, width)``.
        global_attn_indexes: Block indexes whose tables are processed.
    """
    for i in global_attn_indexes:
        key_h = f'blocks.{i}.attn.rel_pos_h'
        key_w = f'blocks.{i}.attn.rel_pos_w'
        if key_h not in checkpoint_model:
            continue

        table_h = checkpoint_model[key_h]
        table_w = checkpoint_model[key_w]
        embedding_size = table_h.shape[-1]

        orig_size = (table_h.shape[-2], table_w.shape[-2])
        new_size = (2 * image_size[0] - 1, 2 * image_size[1] - 1)
        if orig_size == new_size:
            continue

        print("- Relative Position interpolate from %dx%d to %dx%d" % (orig_size[0], orig_size[1], new_size[0], new_size[1]))

        resized = []
        for table, old_len, new_len in ((table_h, orig_size[0], new_size[0]),
                                        (table_w, orig_size[1], new_size[1])):
            # Treat the (L, C) table as a 1x1xLxC image; only L changes.
            as_image = table.reshape(old_len, embedding_size).unsqueeze(0).unsqueeze(0)
            as_image = torch.nn.functional.interpolate(as_image,
                                                       size=(new_len, embedding_size),
                                                       mode='bicubic',
                                                       align_corners=False)
            resized.append(as_image.squeeze(0).squeeze(0).reshape(new_len, embedding_size))

        checkpoint_model[key_h] = resized[0]
        checkpoint_model[key_w] = resized[1]
|
|
|
|
|
|
# ------------------------ Model Functions ------------------------
|
|
def build_vit_sam(model_name="vit_h", img_size=1024, patch_size=16, img_dim=3, checkpoint=None):
    """Build one of the SAM ViT image encoders.

    Args:
        model_name: One of ``"vit_b"``, ``"vit_l"`` or ``"vit_h"``.
        img_size: Input resolution, a single int (square input) or an
            ``(height, width)`` pair.
        patch_size: Side length of the square patches.
        img_dim: Number of input image channels.
        checkpoint: Path of a SAM checkpoint, or ``None`` to skip loading.

    Returns:
        The constructed ``ImageEncoderViT``.

    Raises:
        ValueError: If ``model_name`` is not a known variant.
    """
    # Per-variant hyper-parameters; everything else is shared by all variants.
    variants = {
        "vit_b": dict(embed_dim=768, depth=12, num_heads=12, global_attn_indexes=(2, 5, 8, 11)),
        "vit_l": dict(embed_dim=1024, depth=24, num_heads=16, global_attn_indexes=(5, 11, 17, 23)),
        "vit_h": dict(embed_dim=1280, depth=32, num_heads=16, global_attn_indexes=(7, 15, 23, 31)),
    }
    if model_name not in variants:
        raise ValueError(
            "Unknown model_name {!r}; expected one of {}".format(model_name, sorted(variants))
        )

    # The encoder indexes img_size as (H, W); expand a bare int to a square size.
    if isinstance(img_size, int):
        img_size = (img_size, img_size)

    return ImageEncoderViT(img_size=img_size,
                           patch_size=patch_size,
                           in_chans=img_dim,
                           mlp_ratio=4.0,
                           norm_layer=partial(nn.LayerNorm, eps=1e-6),
                           window_size=14,
                           checkpoint=checkpoint,
                           **variants[model_name]
                           )
|
|
|
|
|
|
if __name__ == '__main__':
    # Smoke test: build the ViT-B encoder from a SAM checkpoint and move it,
    # together with a random input batch, to the selected device.

    # Prepare an image as the input
    bs, c, h, w = 2, 3, 256, 256
    x = torch.randn(bs, c, h, w)

    # Prefer the second GPU as before, but fall back to the first one when
    # only a single GPU is present (where 'cuda:1' would be an invalid device).
    if torch.cuda.is_available():
        gpu_index = 1 if torch.cuda.device_count() > 1 else 0
        device = torch.device(f'cuda:{gpu_index}')
    else:
        device = torch.device('cpu')

    # Build model
    model = build_vit_sam(model_name='vit_b', checkpoint="/root/autodl-tmp/code/ViTPose/checkpoints/sam/sam_vit_b_01ec64.pth", img_size=(h, w))
    if torch.cuda.is_available():
        x = x.to(device)
        model = model.to(device)

    # # Inference
    # outputs = model(x)
    # print(outputs.shape)

    # # Compute FLOPs & Params (needs the third-party `thop` package)
    # from thop import profile
    # print('==============================')
    # model.eval()
    # flops, params = profile(model, inputs=(x, ), verbose=False)
    # print('GFLOPs : {:.2f}'.format(flops / 1e9 * 2))
    # print('Params : {:.2f} M'.format(params / 1e6))
|
|
|