You cannot select more than 25 topics.
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
106 lines
3.2 KiB
106 lines
3.2 KiB
# Copyright (c) OpenMMLab. All rights reserved.
|
|
from functools import wraps
|
|
from queue import Queue
|
|
from typing import Dict, List, Optional
|
|
|
|
from mmcv import is_seq_of
|
|
|
|
# Names exported by ``from <this module> import *``.
__all__ = [
    'BufferManager',
]
|
|
|
|
|
|
def check_buffer_registered(exist=True):
    """Build a decorator that validates a buffer name before a method call.

    The decorated callable must have the signature ``(manager, name, ...)``,
    where ``manager`` supports the ``name in manager`` membership test.

    Args:
        exist (bool): If truthy, the wrapped call requires ``name`` to be
            registered already; otherwise it requires ``name`` to be not yet
            registered. Defaults to True.

    Returns:
        Callable: A decorator that performs the check above and then calls
        the original function with all arguments unchanged.

    Raises:
        ValueError: Raised by the decorated function when the check fails.
    """

    def decorator(func):

        @wraps(func)
        def checked_call(manager, name, *args, **kwargs):
            registered = name in manager
            if exist and not registered:
                raise ValueError(f'Fail to call {func.__name__}: '
                                 f'buffer "{name}" is not registered.')
            if not exist and registered:
                raise ValueError(f'Fail to call {func.__name__}: '
                                 f'buffer "{name}" is already registered.')
            return func(manager, name, *args, **kwargs)

        return checked_call

    return decorator
|
|
|
|
|
|
class Buffer(Queue):
    """A FIFO :class:`queue.Queue` that can evict old items on demand."""

    def put_force(self, item):
        """Put ``item`` into the buffer without ever blocking.

        If the buffer is bounded (``maxsize > 0``) and already full, the
        earliest items in the buffer are removed to make room for ``item``.

        Args:
            item: The object to enqueue.
        """
        with self.mutex:
            bounded = self.maxsize > 0
            # Evict from the head until there is room. An evicted item can
            # never be retrieved and acknowledged via ``task_done()``, so it
            # is also dropped from the unfinished-task count.
            while bounded and self._qsize() >= self.maxsize:
                self._get()
                self.unfinished_tasks -= 1

            self._put(item)
            self.unfinished_tasks += 1
            # Wake one consumer that may be blocked in ``get()``.
            self.not_empty.notify()
|
|
|
|
|
|
class BufferManager:
    """Manage a collection of named buffers.

    Every buffer is an instance of ``buffer_type`` and is addressed by name.
    Methods that take a buffer name validate it with
    ``check_buffer_registered`` and raise ``ValueError`` when the name is not
    registered (or, for :meth:`register_buffer`, is already registered).

    Args:
        buffer_type (type): Class used to create new buffers in
            :meth:`register_buffer`; every value of ``buffers`` must also be
            an instance of it. Defaults to :class:`Buffer`.
        buffers (dict, optional): Initial mapping from buffer name to buffer
            instance. The mapping is shallow-copied, so the buffer objects
            themselves are shared with the caller. Defaults to None, meaning
            the manager starts with no buffers.

    Raises:
        ValueError: If any value of ``buffers`` is not an instance of
            ``buffer_type``.
    """

    def __init__(self,
                 buffer_type: type = Buffer,
                 buffers: Optional[Dict] = None):
        self.buffer_type = buffer_type
        if buffers is None:
            self._buffers = {}
        elif is_seq_of(list(buffers.values()), buffer_type):
            # Copy the mapping (not the buffers) so later registrations on
            # this manager do not mutate the caller's dict.
            self._buffers = buffers.copy()
        else:
            raise ValueError('The values of buffers should be instance '
                             f'of {buffer_type}')

    def __contains__(self, name):
        """Return whether a buffer called ``name`` is registered."""
        return name in self._buffers

    @check_buffer_registered(False)
    def register_buffer(self, name, maxsize=0):
        """Create and register a new buffer.

        Args:
            name (str): Buffer name; must not already be registered.
            maxsize (int): Passed to ``buffer_type`` as its capacity. For a
                ``queue.Queue``-based buffer, ``maxsize <= 0`` means
                unbounded. Defaults to 0.
        """
        self._buffers[name] = self.buffer_type(maxsize)

    @check_buffer_registered()
    def put(self, name, item, block=True, timeout=None):
        """Put ``item`` into buffer ``name``.

        ``block`` and ``timeout`` are forwarded to the buffer's ``put``.
        """
        self._buffers[name].put(item, block, timeout)

    @check_buffer_registered()
    def put_force(self, name, item):
        """Put ``item`` into buffer ``name`` via the buffer's ``put_force``.

        For the default :class:`Buffer`, this evicts the earliest items
        instead of blocking when the buffer is full.
        """
        self._buffers[name].put_force(item)

    @check_buffer_registered()
    def get(self, name, block=True, timeout=None):
        """Remove and return an item from buffer ``name``.

        ``block`` and ``timeout`` are forwarded to the buffer's ``get``.
        """
        return self._buffers[name].get(block, timeout)

    @check_buffer_registered()
    def is_empty(self, name):
        """Return whether buffer ``name`` is empty."""
        return self._buffers[name].empty()

    @check_buffer_registered()
    def is_full(self, name):
        """Return whether buffer ``name`` is full."""
        return self._buffers[name].full()

    def get_sub_manager(self, buffer_names: List[str]):
        """Create a manager over a subset of this manager's buffers.

        Args:
            buffer_names (list[str]): Names of the buffers to include.

        Returns:
            BufferManager: A new manager with the same ``buffer_type``. Its
            buffers are the same objects as in this manager, so items put
            through one manager are visible through the other.

        Raises:
            KeyError: If any name in ``buffer_names`` is not registered.
        """
        buffers = {name: self._buffers[name] for name in buffer_names}
        return BufferManager(self.buffer_type, buffers)

    def get_info(self):
        """Return the current and maximum size of every buffer.

        Returns:
            dict: Maps each buffer name to a dict with keys ``'size'``
            (current number of items) and ``'maxsize'`` (capacity).
        """
        return {
            name: {
                # ``queue.Queue`` exposes its current length through
                # ``qsize()``; it has no ``size`` attribute.
                'size': buffer.qsize(),
                'maxsize': buffer.maxsize,
            }
            for name, buffer in self._buffers.items()
        }
|
|
|