Skip to content

Commit

Permalink
Merge pull request #2 from ancestor-mithril/dev
Browse files Browse the repository at this point in the history
Diversified measuring and printing options.

* Changed behavior when printing elapsed time
    * Previous behavior: elapsed time is either printed to stdout, written to file or to logger.
    * Current behavior: printing, logging and writing are not mutually exclusive anymore. Each is controlled by its own parameter.
* Added new parameter for returning elapsed time in addition to the function's return value.
* Updated classifiers in pyproject.toml.
* Improved test speed.
  • Loading branch information
ancestor-mithril authored Apr 29, 2024
2 parents bedc7e0 + bcc760e commit 749266d
Show file tree
Hide file tree
Showing 6 changed files with 88 additions and 44 deletions.
48 changes: 35 additions & 13 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,11 @@ def batched_euclidean_distance(x: Tensor, y: Tensor) -> Tensor:
a = torch.rand((10000, 800))
b = torch.rand((12000, 800))
batched_euclidean_distance(a, b)
a = a.cuda()
b = b.cuda()
batched_euclidean_distance(a, b) # Cuda device is synchronized if function arguments are on device.

if torch.cuda.is_available():
a = a.cuda()
b = b.cuda()
batched_euclidean_distance(a, b) # Cuda device is synchronized if function arguments are on device.
```
Prints:
```
Expand All @@ -49,8 +51,10 @@ batched_euclidean_distance(CudaTensor[10000, 800], CudaTensor[12000, 800]) -> to
* `show_kwargs` (`bool`): If `True`, displays the keyword arguments according to `display_level`. Default: `False`.
* `display_level` (`int`): The level of verbosity used when printing function arguments ad keyword arguments. If `0`, prints the type of the parameters. If `1`, prints values for all primitive types, shapes for arrays, tensors, dataframes and length for sequences. Otherwise, prints values for all parameters. Default: `1`.
* `sep` (`str`): The separator used when printing function arguments and keyword arguments. Default: `', '`.
* `file_path` (`str`): If not `None`, writes the measurement at the end of the given file path. For thread safe file writing configure use `logger_name` instead. Can't be used in conjunction with `logger_name`. If both `file_path` and `logger_name` are `None`, writes to stdout. Default: `None`.
* `logger_name` (`str`): If not `None`, uses the given logger to print the measurement. Can't be used in conjunction with `file_path`. If both `file_path` and `logger_name` are `None`, writes to stdout. Default: `None`. See [Using a logger](#using-a-logger).
* `stdout` (`bool`): If `True`, writes the elapsed time to stdout. Default: `True`.
* `file_path` (`str`): If not `None`, writes the measurement at the end of the given file path. For thread safe file writing configure use `logger_name` instead. Default: `None`.
* `logger_name` (`str`): If not `None`, uses the given logger to print the measurement. Can't be used in conjunction with `file_path`. Default: `None`. See [Using a logger](#using-a-logger).
* `return_time` (`bool`): If `True`, returns the elapsed time in addition to the wrapped function's return value. Default: `False`.
* `out` (`dict`): If not `None`, stores the elapsed time in nanoseconds in the given dict using the function name as key. If the key already exists, adds the time to the existing value. Default: `None`. See [Storing the elapsed time in a dict](#storing-the-elapsed-time-in-a-dict).

2. `nested_timed` is similar to `timed`, however it is designed to work nicely with multiple timed functions that call each other, displaying both the total execution time and the difference after subtracting other timed functions on the same call stack. See [Nested timing decorator](#nested-timing-decorator).
Expand All @@ -72,7 +76,27 @@ def fibonacci(n: int) -> int:


fibonacci(10000)
# fibonacci() -> total time: 2114100ns
# fibonacci() -> total time: 1114100ns
```

Getting both the function's return value and the elapsed time.
```py
from timed_decorator.simple_timed import timed


@timed(return_time=True)
def fibonacci(n: int) -> int:
assert n > 0
a, b = 0, 1
for _ in range(n):
a, b = b, a + b
return a


value, elapsed = fibonacci(10000)
print(f'10000th fibonacci number has {len(str(value))} digits. Calculating it took {elapsed}ns.')
# fibonacci() -> total time: 1001200ns
# 10000th fibonacci number has 2090 digits. Calculating it took 1001200ns.
```

Set `collect_gc=False` to disable pre-collection of garbage.
Expand All @@ -91,7 +115,7 @@ def fibonacci(n: int) -> int:


fibonacci(10000)
# fibonacci() -> total time: 2062400ns
# fibonacci() -> total time: 1062400ns
```

Using seconds instead of nanoseconds.
Expand All @@ -114,7 +138,7 @@ def recursive_fibonacci(n: int) -> int:


call_recursive_fibonacci(30)
# call_recursive_fibonacci() -> total time: 0.098s
# call_recursive_fibonacci() -> total time: 0.045s
```

Displaying function parameters:
Expand Down Expand Up @@ -305,7 +329,7 @@ logging.basicConfig()
logging.root.setLevel(logging.NOTSET)


@timed(logger_name='TEST_LOGGER')
@timed(logger_name='TEST_LOGGER', stdout=False)
def fn():
sleep(1)

Expand Down Expand Up @@ -333,7 +357,7 @@ logging.root.setLevel(logging.NOTSET)
logging.getLogger('TEST_LOGGER').addHandler(log_handler)


@timed(logger_name='TEST_LOGGER')
@timed(logger_name='TEST_LOGGER', stdout=False)
def fn():
sleep(1)

Expand All @@ -357,7 +381,7 @@ from timed_decorator.simple_timed import timed
ns = {}


@timed(out=ns)
@timed(out=ns, stdout=False)
def fn():
sleep(1)

Expand All @@ -369,8 +393,6 @@ print(ns)
```
Prints
```
fn() -> total time: 1000767300ns
{'fn': 1000767300}
fn() -> total time: 1000238800ns
{'fn': 2001006100}
```
4 changes: 3 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "timed-decorator"
version = "1.2.2"
version = "1.3.0"
#requires-python = ">=3.10"
requires-python = ">=3.7"
description = "A timing decorator for python functions."
Expand All @@ -15,6 +15,8 @@ maintainers = [
classifiers = [
"Development Status :: 5 - Production/Stable",
"Intended Audience :: Developers",
"License :: OSI Approved :: MIT License",
"Operating System :: OS Independent",
"Programming Language :: Python :: 3",
]

Expand Down
34 changes: 22 additions & 12 deletions tests/test_usage.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,22 +50,22 @@ def sleeping_fn(x):

@nested_timed(collect_gc=False, use_seconds=True, precision=3)
def other_fn():
sleep(0.5)
sleeping_fn(0.5)
sleep(0.1)
sleeping_fn(0.1)

sleep(1)
sleeping_fn(1)
sleep(0.1)
sleeping_fn(0.1)
other_fn()
sleeping_fn(1)
sleeping_fn(0.1)

nested_fn()

def test_file_usage(self):
filename = 'file.txt'

@timed(file_path=filename)
@timed(file_path=filename, stdout=False)
def fn():
sleep(1)
sleep(0.5)

try:
fn()
Expand All @@ -86,9 +86,9 @@ def test_logger_usage(self):
logging.root.setLevel(logging.NOTSET)
logging.getLogger(logger_name).addHandler(log_handler)

@timed(logger_name=logger_name)
@timed(logger_name=logger_name, stdout=False)
def fn():
sleep(1)
sleep(0.5)

fn()
fn()
Expand All @@ -101,14 +101,24 @@ def fn():
def test_ns_output(self):
ns = {}

@timed(out=ns)
@timed(out=ns, stdout=False)
def fn():
sleep(1)
sleep(0.5)

fn()

self.assertIsInstance(ns[fn.__name__], int)
self.assertGreater(ns[fn.__name__], 1**9)
self.assertGreater(ns[fn.__name__], 1**9 / 2)

def test_return_time(self):
@timed(return_time=True, stdout=False)
def fn():
sleep(0.5)

_, elapsed = fn()

self.assertIsInstance(elapsed, int)
self.assertGreater(elapsed, 1**9 / 2)


if __name__ == '__main__':
Expand Down
16 changes: 10 additions & 6 deletions timed_decorator/nested_timed.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,10 @@ def nested_timed(collect_gc: bool = True,
show_kwargs: bool = False,
display_level: int = 1,
sep: str = ', ',
stdout: bool = True,
file_path: Union[str, None] = None,
logger_name: Union[str, None] = None,
return_time: bool = False,
out: dict = None):
"""
A nested timing decorator that measures the time elapsed during the function call and accounts for other decorators
Expand All @@ -40,20 +42,20 @@ def nested_timed(collect_gc: bool = True,
prints the type of the parameters. If `1`, prints values for all primitive types, shapes for arrays,
tensors, dataframes and length for sequences. Otherwise, prints values for all parameters. Default: `1`.
sep (str): The separator used when printing function arguments and keyword arguments. Default: `', '`.
stdout (bool): If `True`, writes the elapsed time to stdout. Default: `True`.
file_path (str): If not `None`, writes the measurement at the end of the given file path. For thread safe
file writing configure use `logger_name` instead. Can't be used in conjunction with `logger_name`. If both
`file_path` and `logger_name` are `None`, writes to stdout. Default: `None`.
file writing configure use `logger_name` instead. Default: `None`.
logger_name (str): If not `None`, uses the given logger to print the measurement. Can't be used in conjunction
with `file_path`. If both `file_path` and `logger_name` are `None`, writes to stdout. Default: `None`.
with `file_path`. Default: `None`.
return_time (bool): If `True`, returns the elapsed time in addition to the wrapped function's return value.
Default: `False`.
out (dict): If not `None`, stores the elapsed time in nanoseconds in the given dict using the function name as
key. If the key already exists, adds the time to the existing value. Default: `None`.
"""
assert file_path is None or logger_name is None

gc_collect = collect if collect_gc else nop
time_formatter = TimeFormatter(use_seconds, precision)
input_formatter = InputFormatter(show_args, show_kwargs, display_level, sep)
logger = Logger(file_path, logger_name)
logger = Logger(stdout, file_path, logger_name)
ns_out = write_mutable if out is not None else nop

def decorator(fn):
Expand Down Expand Up @@ -97,6 +99,8 @@ def wrap(*args, **kwargs):
logger('\t' * nested_level + f'{input_formatter(fn.__name__, *args, **kwargs)} '
f'-> total time: {time_formatter(elapsed)}, '
f'own time: {time_formatter(own_time)}')
if return_time:
return ret, elapsed
return ret

return wrap
Expand Down
16 changes: 10 additions & 6 deletions timed_decorator/simple_timed.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,10 @@ def timed(collect_gc: bool = True,
show_kwargs: bool = False,
display_level: int = 1,
sep: str = ', ',
stdout: bool = True,
file_path: Union[str, None] = None,
logger_name: Union[str, None] = None,
return_time: bool = False,
out: dict = None):
"""
A simple timing decorator that measures the time elapsed during the function call and prints it.
Expand All @@ -36,20 +38,20 @@ def timed(collect_gc: bool = True,
prints the type of the parameters. If `1`, prints values for all primitive types, shapes for arrays,
tensors, dataframes and length for sequences. Otherwise, prints values for all parameters. Default: `1`.
sep (str): The separator used when printing function arguments and keyword arguments. Default: `', '`.
stdout (bool): If `True`, writes the elapsed time to stdout. Default: `True`.
file_path (str): If not `None`, writes the measurement at the end of the given file path. For thread safe
file writing configure use `logger_name` instead. Can't be used in conjunction with `logger_name`. If both
`file_path` and `logger_name` are `None`, writes to stdout. Default: `None`.
file writing configure use `logger_name` instead. Default: `None`.
logger_name (str): If not `None`, uses the given logger to print the measurement. Can't be used in conjunction
with `file_path`. If both `file_path` and `logger_name` are `None`, writes to stdout. Default: `None`.
with `file_path`. Default: `None`.
return_time (bool): If `True`, returns the elapsed time in addition to the wrapped function's return value.
Default: `False`.
out (dict): If not `None`, stores the elapsed time in nanoseconds in the given dict using the function name as
key. If the key already exists, adds the time to the existing value. Default: `None`.
"""
assert file_path is None or logger_name is None

gc_collect = collect if collect_gc else nop
time_formatter = TimeFormatter(use_seconds, precision)
input_formatter = InputFormatter(show_args, show_kwargs, display_level, sep)
logger = Logger(file_path, logger_name)
logger = Logger(stdout, file_path, logger_name)
ns_out = write_mutable if out is not None else nop

def decorator(fn):
Expand All @@ -71,6 +73,8 @@ def wrap(*args, **kwargs):
elapsed = end - start
ns_out(out, fn.__name__, elapsed)
logger(f'{input_formatter(fn.__name__, *args, **kwargs)} -> total time: {time_formatter(elapsed)}')
if return_time:
return ret, elapsed
return ret

return wrap
Expand Down
14 changes: 8 additions & 6 deletions timed_decorator/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,20 +50,22 @@ def __call__(self, nanoseconds):


class Logger:
def __init__(self, file_path: Union[str, None], logger_name: Union[str, None]):
assert file_path is None or logger_name is None

def __init__(self, stdout: bool, file_path: Union[str, None], logger_name: Union[str, None]):
self.stdout = stdout
self.file_path = file_path
self.logger_name = logger_name

def __call__(self, string: str):
if self.stdout:
print(string)

if self.file_path is not None:
with open(self.file_path, 'a') as f:
f.write(string + '\n')
elif self.logger_name is not None:

if self.logger_name is not None:
logging.getLogger(self.logger_name).info(string)
else:
print(string)



class InputFormatter:
Expand Down

0 comments on commit 749266d

Please sign in to comment.