Browse Source

Merge pull request #79 from seaman1900/patch-2

modify the model_split.py file
233
Yufei Xu 2 years ago
committed by GitHub
parent
commit
d521645279
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
  1. 13
      tools/model_split.py

13
tools/model_split.py

@ -40,7 +40,7 @@ def main():
value = torch.cat([value, experts[key.replace('fc2.', f'experts.{target_expert}.')]], dim=0)
new_ckpt['state_dict'][key] = value
torch.save(new_ckpt, os.path.join(args.targetPath, 'coco.pth'))
torch.save(new_ckpt, os.path.join(args.target, 'coco.pth'))
names = ['aic', 'mpii', 'ap10k', 'apt36k','wholebody']
num_keypoints = [14, 16, 17, 17, 133]
@ -87,6 +87,17 @@ def main():
for tensor_name in ['keypoint_head.final_layer.weight', 'keypoint_head.final_layer.bias']:
new_ckpt['state_dict'][tensor_name] = new_ckpt['state_dict'][tensor_name][:num_keypoints[i]]
# remove unnecessary part in the state dict
for j in range(5):
# remove associate part
for tensor_name in weight_names:
new_ckpt['state_dict'].pop(tensor_name.replace('keypoint_head', f'associate_keypoint_heads.{j}'))
# remove expert part
keys = new_ckpt['state_dict'].keys()
for key in list(keys):
if 'expert' in keys:
new_ckpt['state_dict'].pop(key)
torch.save(new_ckpt, os.path.join(args.target, f'{names[i]}.pth'))
if __name__ == '__main__':

Loading…
Cancel
Save