# pylint: disable=missing-docstring
# pylint: enable=missing-docstring
# Copyright 2017 Spotify AB
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import collections
import contextlib
import functools
import importlib
import logging
import math
import time
# Module-level logger; `deprecated` uses it to emit its warnings.
LOGGER = logging.getLogger(__name__)
class lazy_import:  # pylint: disable=invalid-name, too-few-public-methods
    """
    Lazily import the given module.

    The module is imported the first time any attribute is accessed on this
    proxy; later lookups reuse the cached module object.

    Parameters
    ----------
    module : str
        Name of the module to import. Dotted names (e.g. ``"os.path"``)
        resolve to the named submodule itself, not its top-level package.
    """
    def __init__(self, module):
        self.module = module
        self._module = None

    def __getattr__(self, name):
        # `__getattr__` only runs for attributes missing from the instance.
        # If `__init__` never ran (e.g. an instance created via `__new__` during
        # copying/unpickling), reading `self._module` below would re-enter this
        # method forever; fail with a normal AttributeError instead.
        if name in ('module', '_module'):
            raise AttributeError(name)
        if self._module is None:
            # `importlib.import_module("a.b")` returns the `a.b` module, whereas
            # `__import__("a.b")` returns the top-level package `a`.
            self._module = importlib.import_module(self.module)
        return getattr(self._module, name)
class batch_iterable:  # pylint: disable=invalid-name, too-few-public-methods
    """
    Split an iterable into batches of a specified size.

    Every batch holds exactly `batch_size` items except possibly the final
    one, which holds whatever items remain.

    Parameters
    ----------
    iterable : iterable
        Iterable to split into batches. It must support `len` for `len` of
        this object to work.
    batch_size : int
        Size of each batch; must be positive.
    transpose : bool
        Whether to transpose each batch, yielding a tuple of per-position
        tuples instead of a list of items.
    """
    def __init__(self, iterable, batch_size, transpose=False):
        self.iterable = iterable
        if batch_size <= 0:
            message = "`batch_size` must be positive but got '%s'" % batch_size
            raise ValueError(message)
        self.batch_size, self.transpose = batch_size, transpose

    def __len__(self):
        # Number of batches, counting a trailing partial batch.
        num_items = len(self.iterable)
        return math.ceil(num_items / self.batch_size)

    def __iter__(self):
        chunk = []
        for element in self.iterable:
            chunk.append(element)
            if len(chunk) != self.batch_size:
                continue
            yield self._package(chunk)
            # Start a fresh list so the batch just yielded is never mutated.
            chunk = []
        if chunk:
            yield self._package(chunk)

    def _package(self, chunk):
        """Return `chunk` unchanged, or transposed into a tuple of tuples."""
        if self.transpose:
            return tuple(zip(*chunk))
        return chunk
class Profiler:  # pylint: disable=too-few-public-methods
    """
    Callback for profiling computational graphs.

    Calling an instance with ``(operation, context)`` returns a context
    manager that measures how long the wrapped block takes and stores the
    duration under `operation`.

    Attributes
    ----------
    times : dict[Operation, float]
        Mapping from operations to execution times in seconds. If an operation
        is timed more than once, only the latest measurement is kept.
    """
    def __init__(self):
        self.times = {}

    def get_slow_operations(self, num_operations=None):
        """
        Get the slowest operations.

        Parameters
        ----------
        num_operations : int or None
            Maximum number of operations to return or `None` for all of them.

        Returns
        -------
        times : collections.OrderedDict
            Mapping of execution times keyed by operations, ordered from the
            slowest to the fastest. Operations with equal times keep their
            insertion order.
        """
        items = sorted(self.times.items(), key=lambda item: item[1], reverse=True)
        if num_operations is not None:
            items = items[:num_operations]
        return collections.OrderedDict(items)

    @contextlib.contextmanager
    def __call__(self, operation, context):  # pylint: disable=unused-argument
        # `context` is not used here; it is accepted so the profiler matches the
        # callback call signature (presumably defined by the caller - confirm).
        # `perf_counter` is monotonic and high-resolution. `time.time` follows
        # the system clock, which can be adjusted and yield wrong or negative
        # durations.
        start = time.perf_counter()
        yield
        # Not reached if the wrapped block raises, so failed operations are not
        # recorded.
        self.times[operation] = time.perf_counter() - start

    def __str__(self):
        # Report only the ten slowest operations, one "operation: seconds" per line.
        return "\n".join(['%s: %s' % item for item in self.get_slow_operations(10).items()])
@contextlib.contextmanager
def _noop_callback(*_):
yield
def deprecated(func):  # pragma: no cover
    """
    Mark a callable as deprecated.

    Parameters
    ----------
    func : callable
        Callable to wrap.

    Returns
    -------
    wrapper : callable
        Callable that logs a warning through the module-level `LOGGER` on every
        invocation and then delegates to `func` with the same arguments. It
        carries `func`'s name, docstring and `__wrapped__` reference.
    """
    # `functools.wraps` keeps `__name__`, `__doc__`, etc. so introspection and
    # generated docs still describe the original callable.
    @functools.wraps(func)
    def _wrapper(*args, **kwargs):
        LOGGER.warning("%s is deprecated", func)
        return func(*args, **kwargs)
    return _wrapper