# Copyright (c) OpenMMLab. All rights reserved.
import time
from typing import Dict, List, Optional, Union

from mmpose.apis import (get_track_id, inference_top_down_pose_model,
                         init_pose_model)

from ..utils import Message
from .builder import NODES
from .node import Node


@NODES.register_module()
class TopDownPoseEstimatorNode(Node):
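    """Perform top-down pose estimation using detection results from an
    upstream detector node, with simple pose tracking across frames.

    Args:
        name (str): The node name, also used as the tag of the output pose
            results.
        model_config (str): Path to the pose model config file.
        model_checkpoint (str): Path or URL of the pose model checkpoint.
        input_buffer (str): Name of the input buffer providing frame messages
            with detection results.
        output_buffer (str | list): Name(s) of the output buffer(s).
        enable_key (str | int, optional): A key to toggle the node at runtime.
            Default: ``None``.
        enable (bool): Whether the node is enabled at start. Default: ``True``.
        device (str): Device for model inference. Default: ``'cuda:0'``.
        cls_ids (list, optional): If set, only detections with these class IDs
            are used for pose estimation. Default: ``None``.
        cls_names (list, optional): If set, only detections with these class
            names (labels) are used for pose estimation. Default: ``None``.
        bbox_thr (float): Bbox score threshold for pose estimation.
            Default: 0.5.
    """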
    def __init__(self,
                 name: str,
                 model_config: str,
                 model_checkpoint: str,
                 input_buffer: str,
                 output_buffer: Union[str, List[str]],
                 enable_key: Optional[Union[str, int]] = None,
                 enable: bool = True,
                 device: str = 'cuda:0',
                 cls_ids: Optional[List] = None,
                 cls_names: Optional[List] = None,
                 bbox_thr: float = 0.5):
        super().__init__(name=name, enable_key=enable_key, enable=enable)

        # Store model configs
        self.model_config = model_config
        self.model_checkpoint = model_checkpoint
        self.device = device.lower()
        self.cls_ids = cls_ids
        self.cls_names = cls_names
        self.bbox_thr = bbox_thr

        # Init model
        self.model = init_pose_model(
            self.model_config,
            self.model_checkpoint,
            device=self.device)

        # Store history for pose tracking
        self.track_info = {
            'next_id': 0,
            'last_pose_preds': [],
            'last_time': None
        }

        # Register buffers
        self.register_input_buffer(input_buffer, 'input', essential=True)
        self.register_output_buffer(output_buffer)

    def bypass(self, input_msgs):
        """Directly forward the input message when the node is disabled."""
        return input_msgs['input']

    def process(self, input_msgs: Dict[str, Message]) -> Message:
        """Estimate poses of the detected instances in the input frame and
        attach the results to the frame message."""
        input_msg = input_msgs['input']

        img = input_msg.get_image()
        det_results = input_msg.get_detection_results()

        if det_results is None:
            raise ValueError(
                'No detection results are found in the frame message. '
                f'{self.__class__.__name__} should be used after a '
                'detector node.')

        full_det_preds = []
        for det_result in det_results:
            det_preds = det_result['preds']
            if self.cls_ids:
                # Filter detection results by class ID
                det_preds = [
                    p for p in det_preds if p['cls_id'] in self.cls_ids
                ]
            elif self.cls_names:
                # Filter detection results by class name
                det_preds = [
                    p for p in det_preds if p['label'] in self.cls_names
                ]
            full_det_preds.extend(det_preds)

        # Inference pose
        pose_preds, _ = inference_top_down_pose_model(
            self.model,
            img,
            full_det_preds,
            bbox_thr=self.bbox_thr,
            format='xyxy')

        # Pose tracking
        current_time = time.time()
        if self.track_info['last_time'] is None:
            fps = None
        elif self.track_info['last_time'] >= current_time:
            # Guard against a zero or negative frame interval
            fps = None
        else:
            fps = 1.0 / (current_time - self.track_info['last_time'])

        pose_preds, next_id = get_track_id(
            pose_preds,
            self.track_info['last_pose_preds'],
            self.track_info['next_id'],
            use_oks=False,
            tracking_thr=0.3,
            use_one_euro=True,
            fps=fps)

        self.track_info['next_id'] = next_id
        self.track_info['last_pose_preds'] = pose_preds.copy()
        self.track_info['last_time'] = current_time

        pose_result = {
            'preds': pose_preds,
            'model_cfg': self.model.cfg.copy(),
        }
        input_msg.add_pose_result(pose_result, tag=self.name)

        return input_msg
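

# Example config for building this node via the ``NODES`` registry (an
# illustrative sketch; the config/checkpoint paths and buffer names are
# placeholders, not values from this file):
#
#     dict(
#         type='TopDownPoseEstimatorNode',
#         name='pose estimator',
#         model_config='path/to/topdown_pose_config.py',
#         model_checkpoint='path/to/pose_checkpoint.pth',
#         input_buffer='det_result',
#         output_buffer='pose_result')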