@@ -40,7 +40,7 @@ def main():
             value = torch.cat([value, experts[key.replace('fc2.', f'experts.{target_expert}.')]], dim=0)
         new_ckpt['state_dict'][key] = value
 
-    torch.save(new_ckpt, os.path.join(args.targetPath, 'coco.pth'))
+    torch.save(new_ckpt, os.path.join(args.target, 'coco.pth'))
 
     names = ['aic', 'mpii', 'ap10k', 'apt36k','wholebody']
     num_keypoints = [14, 16, 17, 17, 133]
@@ -87,6 +87,17 @@ def main():
         for tensor_name in ['keypoint_head.final_layer.weight', 'keypoint_head.final_layer.bias']:
             new_ckpt['state_dict'][tensor_name] = new_ckpt['state_dict'][tensor_name][:num_keypoints[i]]
 
+        # remove unnecessary parts of the state dict
+        for j in range(5):
+            # remove the associate keypoint heads
+            for tensor_name in weight_names:
+                new_ckpt['state_dict'].pop(tensor_name.replace('keypoint_head', f'associate_keypoint_heads.{j}'))
+        # remove the expert branches
+        keys = new_ckpt['state_dict'].keys()
+        for key in list(keys):
+            if 'expert' in key:
+                new_ckpt['state_dict'].pop(key)
+
         torch.save(new_ckpt, os.path.join(args.target, f'{names[i]}.pth'))
 
 if __name__ == '__main__':
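For reference, a minimal sketch (not part of the patch) of how one of the split checkpoints could be sanity-checked after running the script. The `splits/wholebody.pth` path is hypothetical and stands in for whichever target directory the script was given:

```python
import torch

# Hypothetical output path; substitute the directory the split script wrote to.
ckpt = torch.load('splits/wholebody.pth', map_location='cpu')

# The final prediction layer should now be trimmed to this dataset's keypoint
# count (133 for wholebody), and no expert or associate-head tensors should remain.
print(ckpt['state_dict']['keypoint_head.final_layer.weight'].shape)
assert not any('expert' in k or 'associate_keypoint_heads' in k
               for k in ckpt['state_dict'])
```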