You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
117 lines
3.7 KiB
117 lines
3.7 KiB
# Copyright (c) OpenMMLab. All rights reserved.
|
|
from collections import defaultdict, deque
from contextlib import contextmanager
from functools import partial

import numpy as np

from mmcv import Timer
|
|
|
|
|
|
class RunningAverage:
    r"""A helper class to calculate running average in a sliding window.

    Only the ``window`` most recent samples are kept; older samples are
    discarded automatically as new ones arrive.

    Args:
        window (int): The size of the sliding window. Must be a positive
            integer. Default: 1.

    Raises:
        ValueError: If ``window`` is smaller than 1.
    """

    def __init__(self, window: int = 1):
        # A non-positive window previously slipped through silently:
        # ``data[-0:]`` keeps the whole (ever-growing) history and a negative
        # window empties the buffer, so reject it up front.
        if window < 1:
            raise ValueError(
                f'window must be a positive integer, but got {window}')
        self.window = window
        # ``deque(maxlen=...)`` evicts the oldest sample in O(1) instead of
        # re-slicing (copying) a list on every update.
        self._data = deque(maxlen=window)

    def update(self, value):
        """Update a new data sample.

        Args:
            value (float): The new sample. When the window is already full,
                the oldest sample is dropped.
        """
        self._data.append(value)

    def average(self):
        """Get the average value of current window.

        Returns:
            float: ``np.mean`` of the samples currently held. If no sample has
            been added yet this is the mean of an empty sequence (NaN, with a
            numpy RuntimeWarning).
        """
        return np.mean(self._data)
|
|
|
|
|
|
class StopWatch:
    r"""A helper class to measure FPS and detailed time consuming of each phase
    in a video processing loop or similar scenarios.

    Args:
        window (int): The sliding window size to calculate the running average
            of the time consuming. Default: 1.

    Example:
        >>> from mmpose.utils import StopWatch
        >>> import time
        >>> stop_watch = StopWatch(window=10)
        >>> with stop_watch.timeit('total'):
        ...     time.sleep(0.1)
        ...     # 'timeit' supports nested use
        ...     with stop_watch.timeit('phase1'):
        ...         time.sleep(0.1)
        ...     with stop_watch.timeit('phase2'):
        ...         time.sleep(0.2)
        ...     time.sleep(0.2)
        >>> report = stop_watch.report()
    """

    def __init__(self, window=1):
        self.window = window
        # Share one code path with ``reset`` so the two cannot drift apart.
        self.reset()

    @contextmanager
    def timeit(self, timer_name='_FPS_'):
        """Time a code snippet with an assigned name.

        Args:
            timer_name (str): The unique name of the interested code snippet to
                handle multiple timers and generate reports. Note that '_FPS_'
                is a special key that the measurement will be in `fps` instead
                of `millisecond`. Also see `report` and `report_strings`.
                Default: '_FPS_'.

        Note:
            This function should always be used in a `with` statement, as shown
            in the example.
        """
        self._timer_stack.append((timer_name, Timer()))
        try:
            yield
        finally:
            # Nested ``with`` blocks exit in LIFO order, so the top of the
            # stack is the timer pushed by this very call.
            timer_name, timer = self._timer_stack.pop()
            self._record[timer_name].update(timer.since_start())

    def report(self, key=None):
        """Report timing information.

        Args:
            key (str, optional): If given, return only the value of this
                timer instead of the whole dict. Default: None.

        Returns:
            dict | float: If ``key`` is None, a dict whose key is the timer
            name and whose value is the corresponding average time consuming
            in milliseconds (the special '_FPS_' entry is converted to frames
            per second). Otherwise the single value for ``key``.

        Raises:
            KeyError: If ``key`` is given but no timer with that name has
                been recorded.
        """
        result = {
            name: r.average() * 1000.
            for name, r in self._record.items()
        }

        if '_FPS_' in result:
            # Average per-frame time in ms -> frames per second.
            result['_FPS_'] = 1000. / result.pop('_FPS_')

        if key is None:
            return result
        return result[key]

    def report_strings(self):
        """Report timing information in texture strings.

        Returns:
            list(str): Each element is the information string of a timed \
                event, in format of '{timer_name}: {time_in_ms}'. \
                Specially, if timer_name is '_FPS_', the result will \
                be converted to fps and reported once, as the first \
                element 'FPS: {fps}'.
        """
        result = self.report()
        strings = []
        if '_FPS_' in result:
            strings.append(f'FPS: {result["_FPS_"]:>5.1f}')
        # '_FPS_' has already been reported above in its own format; skip it
        # here so it is not emitted a second time as a millisecond entry.
        strings += [
            f'{name}: {val:>3.0f}' for name, val in result.items()
            if name != '_FPS_'
        ]
        return strings

    def reset(self):
        """Discard all recorded timings and clear the timer stack.

        Note:
            Calling this from inside an active ``timeit`` block empties the
            stack that block pops from on exit, so its exit will fail.
        """
        # Must rebuild the same per-timer RunningAverage factory as at
        # construction; a plain ``list`` has no ``update``/``average``.
        self._record = defaultdict(partial(RunningAverage, window=self.window))
        self._timer_stack = []
|
|
|