You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
77 lines
2.2 KiB
77 lines
2.2 KiB
# Copyright (c) OpenMMLab. All rights reserved.
|
|
import pytest
|
|
import torch
|
|
|
|
from mmpose.models import HourglassAENet, HourglassNet
|
|
|
|
|
|
def test_hourglass_backbone():
    """Check HourglassNet's argument validation and output shapes.

    Invalid configurations must be rejected with ``AssertionError``. Valid
    single- and two-stack models must return one feature map per stack,
    each of shape ``(1, 256, 64, 64)`` for a ``(1, 3, 256, 256)`` input.
    """
    with pytest.raises(AssertionError):
        # HourglassNet's num_stacks should be larger than 0
        HourglassNet(num_stacks=0)

    with pytest.raises(AssertionError):
        # len(stage_channels) should equal len(stage_blocks)
        HourglassNet(
            stage_channels=[256, 256, 384, 384, 384],
            stage_blocks=[2, 2, 2, 2, 2, 4])

    with pytest.raises(AssertionError):
        # len(stage_channels) should be larger than downsample_times
        HourglassNet(
            downsample_times=5,
            stage_channels=[256, 256, 384, 384, 384],
            stage_blocks=[2, 2, 2, 2, 2])

    # num_stacks=1 is HourglassNet-52; num_stacks=2 is HourglassNet-104.
    for num_stacks in (1, 2):
        model = HourglassNet(num_stacks=num_stacks)
        model.init_weights()
        model.train()

        imgs = torch.randn(1, 3, 256, 256)
        feat = model(imgs)
        # Expect one output per stack, and every stack emits the same shape.
        assert len(feat) == num_stacks
        for stack_feat in feat:
            assert stack_feat.shape == torch.Size([1, 256, 64, 64])
|
|
|
|
|
|
def test_hourglass_ae_backbone():
    """Check HourglassAENet's argument validation and output shapes.

    Invalid configurations must be rejected with ``AssertionError``. Valid
    single- and two-stack models must return one feature map per stack,
    each of shape ``(1, 34, 64, 64)`` for a ``(1, 3, 256, 256)`` input.
    """
    with pytest.raises(AssertionError):
        # HourglassAENet's num_stacks should be larger than 0
        HourglassAENet(num_stacks=0)

    with pytest.raises(AssertionError):
        # len(stage_channels) should be larger than downsample_times
        HourglassAENet(
            downsample_times=5, stage_channels=[256, 256, 384, 384, 384])

    # Exercise both the single-stack and the two-stack configurations.
    for num_stacks in (1, 2):
        model = HourglassAENet(num_stacks=num_stacks)
        model.init_weights()
        model.train()

        imgs = torch.randn(1, 3, 256, 256)
        feat = model(imgs)
        # Expect one output per stack, and every stack emits the same shape.
        assert len(feat) == num_stacks
        for stack_feat in feat:
            assert stack_feat.shape == torch.Size([1, 34, 64, 64])
|
|
|