You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

84 lines
2.5 KiB

# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Optional, Union
from .builder import NODES
from .node import Node
try:
from mmdet.apis import inference_detector, init_detector
has_mmdet = True
except (ImportError, ModuleNotFoundError):
has_mmdet = False
@NODES.register_module()
class DetectorNode(Node):
def __init__(self,
name: str,
model_config: str,
model_checkpoint: str,
input_buffer: str,
output_buffer: Union[str, List[str]],
enable_key: Optional[Union[str, int]] = None,
device: str = 'cuda:0'):
# Check mmdetection is installed
assert has_mmdet, 'Please install mmdet to run the demo.'
super().__init__(name=name, enable_key=enable_key, enable=True)
self.model_config = model_config
self.model_checkpoint = model_checkpoint
self.device = device.lower()
# Init model
self.model = init_detector(
self.model_config,
self.model_checkpoint,
device=self.device.lower())
# Register buffers
self.register_input_buffer(input_buffer, 'input', essential=True)
self.register_output_buffer(output_buffer)
def bypass(self, input_msgs):
return input_msgs['input']
def process(self, input_msgs):
input_msg = input_msgs['input']
img = input_msg.get_image()
preds = inference_detector(self.model, img)
det_result = self._post_process(preds)
input_msg.add_detection_result(det_result, tag=self.name)
return input_msg
def _post_process(self, preds):
if isinstance(preds, tuple):
dets = preds[0]
segms = preds[1]
else:
dets = preds
segms = [None] * len(dets)
assert len(dets) == len(self.model.CLASSES)
assert len(segms) == len(self.model.CLASSES)
result = {'preds': [], 'model_cfg': self.model.cfg.copy()}
for i, (cls_name, bboxes,
masks) in enumerate(zip(self.model.CLASSES, dets, segms)):
if masks is None:
masks = [None] * len(bboxes)
else:
assert len(masks) == len(bboxes)
preds_i = [{
'cls_id': i,
'label': cls_name,
'bbox': bbox,
'mask': mask,
} for (bbox, mask) in zip(bboxes, masks)]
result['preds'].extend(preds_i)
return result