From e87121366d71154c5162b85fbb0d63ac188c2288 Mon Sep 17 00:00:00 2001 From: seaman1900 <48203464+seaman1900@users.noreply.github.com> Date: Sat, 25 Feb 2023 11:56:23 +0800 Subject: [PATCH] modify the model_split.py file ``` shell python demo/top_down_img_demo.py configs/wholebody/2d_kpt_sview_rgb_img/topdown_heatmap/coco-wholebody/ViTPose_base_wholebody_256x192.py target/wholebody.pth --img-root tests/data/coco/ --json-file tests/data/coco/test_coco.json --out-img-root vis_results ``` When I use the previous model_split.py, above command occurs error, and the modified version can work properly. --- tools/model_split.py | 17 ++++++++++++++--- 1 file changed, 14 insertions(+), 3 deletions(-) diff --git a/tools/model_split.py b/tools/model_split.py index 3cfe13f..928380a 100644 --- a/tools/model_split.py +++ b/tools/model_split.py @@ -40,7 +40,7 @@ def main(): value = torch.cat([value, experts[key.replace('fc2.', f'experts.{target_expert}.')]], dim=0) new_ckpt['state_dict'][key] = value - torch.save(new_ckpt, os.path.join(args.targetPath, 'coco.pth')) + torch.save(new_ckpt, os.path.join(args.target, 'coco.pth')) names = ['aic', 'mpii', 'ap10k', 'apt36k','wholebody'] num_keypoints = [14, 16, 17, 17, 133] @@ -86,8 +86,19 @@ def main(): for tensor_name in ['keypoint_head.final_layer.weight', 'keypoint_head.final_layer.bias']: new_ckpt['state_dict'][tensor_name] = new_ckpt['state_dict'][tensor_name][:num_keypoints[i]] - + + # remove unnecessary part in the state dict + for j in range(5): + # remove associate part + for tensor_name in weight_names: + new_ckpt['state_dict'].pop(tensor_name.replace('keypoint_head', f'associate_keypoint_heads.{j}')) + # remove expert part + keys = new_ckpt['state_dict'].keys() + for key in list(keys): + if 'expert' in keys: + new_ckpt['state_dict'].pop(key) + torch.save(new_ckpt, os.path.join(args.target, f'{names[i]}.pth')) if __name__ == '__main__': - main() \ No newline at end of file + main()