feat: Time+memory tracking
parent 59722acf76
commit ee4d94e157
7 changed files with 456 additions and 0 deletions
src/utils/benchmark.py (new file, 175 lines)

@@ -0,0 +1,175 @@
"""Utilities functions for benchmarking."""
|
||||
import json
|
||||
import string
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from logging import getLogger
|
||||
from os import getpid, path
|
||||
from pathlib import Path
|
||||
from random import choices
|
||||
from subprocess import DEVNULL, PIPE, CalledProcessError, TimeoutExpired, run
|
||||
from timeit import timeit
|
||||
from typing import Callable
|
||||
|
||||
from memray import Tracker
|
||||
|
||||
from ..utils.benchmark_dataclasses import BenchmarkItem, BenchmarkResult
|
||||
|
||||
log = getLogger(__name__)
|
||||
|
||||
|
||||
def get_commit_hash() -> str:
|
||||
"""
|
||||
Get the commit hash of the current git repository.
|
||||
|
||||
If not working in a git repository, return a random string that looks like a commit hash.
|
||||
"""
|
||||
try:
|
||||
return run(
|
||||
["git", "rev-parse", "--short", "HEAD"],
|
||||
check=True,
|
||||
stdout=PIPE,
|
||||
stderr=DEVNULL,
|
||||
text=True,
|
||||
).stdout.strip()
|
||||
except CalledProcessError as e:
|
||||
log.error(
|
||||
"Could not determine the commit hash. Are you using a git repository?:\n%s",
|
||||
e,
|
||||
)
|
||||
log.error("Using a random string as commit hash.")
|
||||
return "".join(choices(string.hexdigits[:-6], k=40))
|
||||
|
||||
|
||||
def init_stat_file(stat_file: Path, header: str) -> int:
|
||||
"""Initialize a statistics file with a header."""
|
||||
# Check if the parent directory exists
|
||||
stat_file.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Check if the file exists
|
||||
if stat_file.exists():
|
||||
# Nothing left to do
|
||||
return 0
|
||||
|
||||
# Initialize the file by writing the header to it.
|
||||
log.debug("Initializing statistics file %s", stat_file)
|
||||
stat_file.touch()
|
||||
stat_file.write_text(f"{header}\n", encoding="utf-8")
|
||||
return 1
|
||||
|
||||
|
||||
def track_time_memory(task: Callable, result: BenchmarkResult, mem_file: Path, mem_json_file: Path):
|
||||
"""Track the time and memory consumption of a task."""
|
||||
|
||||
def task_with_result():
|
||||
result.value = task()
|
||||
|
||||
# Measure memory consumption
|
||||
with Tracker(file_name=mem_file, native_traces=True, follow_fork=True, memory_interval_ms=1):
|
||||
try:
|
||||
# Measure runtime
|
||||
result.runtime = timeit(task_with_result, number=1, globals=globals())
|
||||
except BaseException as e:
|
||||
log.error("Error while timing the program:\n%s", e, exc_info=True)
|
||||
return None
|
||||
|
||||
# Convert binary memory file into JSON.
|
||||
try:
|
||||
run(
|
||||
[
|
||||
"python",
|
||||
"-m",
|
||||
"memray",
|
||||
"stats",
|
||||
"--json",
|
||||
"--num-largest",
|
||||
"1",
|
||||
"--output",
|
||||
mem_json_file,
|
||||
mem_file,
|
||||
],
|
||||
check=True,
|
||||
timeout=100,
|
||||
stdout=DEVNULL,
|
||||
)
|
||||
# Parse JSON to get peak_memory
|
||||
mem_results = json.loads(mem_json_file.read_text(encoding="utf-8"))
|
||||
result.peak_memory = mem_results["metadata"]["peak_memory"]
|
||||
|
||||
except CalledProcessError as e:
|
||||
log.error(
|
||||
"Something went wrong while processing the memray memory file %s:\n%s",
|
||||
mem_file,
|
||||
e,
|
||||
)
|
||||
except TimeoutExpired as e:
|
||||
log.error(
|
||||
"Timeout expired while processing the memray memory file %s:\n%s}",
|
||||
mem_file,
|
||||
e,
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def execute_benchmark(
|
||||
benchmark_item: BenchmarkItem,
|
||||
results_dir: str | Path,
|
||||
timeout: int = 100,
|
||||
) -> BenchmarkResult:
|
||||
"""Execute a benchmark and track its runtime and peak memory consumption."""
|
||||
mem_file = Path(path.join(results_dir, f"memray-{benchmark_item.task.__name__}.mem"))
|
||||
mem_json_file = Path(path.join(results_dir, f"memray-{benchmark_item.task.__name__}.json"))
|
||||
|
||||
result = BenchmarkResult(benchmark_item)
|
||||
|
||||
try:
|
||||
# Time and track memory usage
|
||||
# Kill after timeout in seconds
|
||||
with ThreadPoolExecutor() as executor:
|
||||
future = executor.submit(
|
||||
lambda: track_time_memory(
|
||||
lambda: benchmark_item.task(**benchmark_item.arguments), result, mem_file, mem_json_file
|
||||
)
|
||||
)
|
||||
executed_result = future.result(timeout=timeout)
|
||||
|
||||
if executed_result is not None:
|
||||
result = executed_result
|
||||
|
||||
log.info(
|
||||
"PID %d: %s finished [%.6f seconds, %d bytes]",
|
||||
getpid(),
|
||||
benchmark_item.get_method(),
|
||||
result.runtime,
|
||||
result.peak_memory,
|
||||
)
|
||||
except TimeoutError:
|
||||
log.error("Timeout expired while running the benchmark_suite, cleaning up now.")
|
||||
|
||||
log.info(
|
||||
"PID %d: %s failed after timeout (%d seconds)",
|
||||
getpid(),
|
||||
benchmark_item.get_method(),
|
||||
timeout,
|
||||
)
|
||||
finally:
|
||||
# Clean up memory dump file to save disk space.
|
||||
mem_file.unlink()
|
||||
|
||||
return result
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import hydra
|
||||
|
||||
# Dummy example, read the contents of the dataset
|
||||
def _read_contents(filename):
|
||||
with open(filename, encoding="utf-8") as f:
|
||||
log.info("Dataset content: %s", f.read())
|
||||
|
||||
def _read_contents_wrapper(cfg):
|
||||
return _read_contents(cfg.dataset.path)
|
||||
|
||||
hydra_wrapped = hydra.main(config_path="../../config", config_name="config", version_base="1.2")(
|
||||
_read_contents_wrapper
|
||||
)()
|
||||
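A minimal usage sketch (not part of this commit) of how these helpers are meant to fit together. The fib task, the results directory, and the stats.dat file name are invented for illustration, and the sketch assumes the src package is importable and memray is installed:

from pathlib import Path

from src.utils.benchmark import execute_benchmark, get_commit_hash, init_stat_file
from src.utils.benchmark_dataclasses import BenchmarkItem


def fib(n: int) -> int:
    """Toy workload to benchmark (hypothetical example)."""
    return n if n < 2 else fib(n - 1) + fib(n - 2)


if __name__ == "__main__":
    results_dir = Path("results")
    results_dir.mkdir(parents=True, exist_ok=True)

    # Run the benchmark: runtime via timeit, peak memory via memray.
    item = BenchmarkItem(task=fib, arguments={"n": 25})
    result = execute_benchmark(item, results_dir, timeout=60)

    # Append one space-separated line per run, tagged with the current commit.
    stat_file = results_dir / "stats.dat"
    init_stat_file(stat_file, f"commit {result.get_header()}")
    with stat_file.open("a", encoding="utf-8") as f:
        f.write(f"{get_commit_hash()} {result}\n")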
src/utils/benchmark_dataclasses.py (new file, 79 lines)

@@ -0,0 +1,79 @@
"""
|
||||
Benchmark data classes.
|
||||
|
||||
This module contains the BenchmarkResult class which is used to store and print the results of a
|
||||
benchmark_suite.
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Callable
|
||||
|
||||
|
||||
@dataclass(init=True)
|
||||
class BenchmarkItem:
|
||||
"""A class used to represent a benchmark_suite (iteration)."""
|
||||
|
||||
task: Callable
|
||||
arguments: dict
|
||||
|
||||
def __str__(self) -> str:
|
||||
"""String representation of the BenchmarkItem object."""
|
||||
return self.get_in_data_format()
|
||||
|
||||
def get_method(self) -> str:
|
||||
"""
|
||||
Format the method as if it were a function call.
|
||||
"""
|
||||
method_name = self.task.__name__
|
||||
arguments = ", ".join(
|
||||
f'{key}={str(value)[:15]}'
|
||||
for key, value in self.arguments.items()
|
||||
)
|
||||
return f"{method_name}({arguments})"
|
||||
|
||||
def get_in_data_format(self) -> str:
|
||||
"""
|
||||
Format the benchmark_suite item to be printed to a .dat file.
|
||||
"""
|
||||
# Flatten out arguments
|
||||
values = list(self.__dict__.values())
|
||||
values[1:2] = values[1].values()
|
||||
|
||||
return " ".join(map(str, values))
|
||||
|
||||
def get_header(self) -> str:
|
||||
"""
|
||||
Returns the header which is just the names of the fields separated by spaces.
|
||||
"""
|
||||
return " ".join(self.__dict__.keys())
|
||||
|
||||
|
||||
@dataclass(init=True)
|
||||
class BenchmarkResult:
|
||||
"""A class used to represent the result of a benchmark_suite."""
|
||||
|
||||
benchmark_item: BenchmarkItem
|
||||
runtime: float = 0
|
||||
peak_memory: int = 0
|
||||
value: Any = None
|
||||
|
||||
def __str__(self) -> str:
|
||||
"""String representation of the BenchmarkResult object."""
|
||||
return self.get_in_data_format()
|
||||
|
||||
def get_in_data_format(self) -> str:
|
||||
"""
|
||||
Format the benchmark_suite result to be printed to a .dat file.
|
||||
"""
|
||||
return " ".join(map(str, self.__dict__.values()))
|
||||
|
||||
def get_header(self) -> str:
|
||||
"""
|
||||
Returns the header which is just the names of the fields separated by spaces.
|
||||
"""
|
||||
# Get header of the BenchmarkItem
|
||||
keys = list(self.__annotations__.keys())
|
||||
keys[0:1] = self.benchmark_item.__annotations__.keys()
|
||||
keys[1:2] = self.benchmark_item.arguments.keys()
|
||||
|
||||
return " ".join(keys)
|
||||
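For illustration (not part of this commit), a rough sketch of what these formatting helpers produce for a hypothetical fib task; the exact first column depends on the function's repr:

from src.utils.benchmark_dataclasses import BenchmarkItem, BenchmarkResult


def fib(n: int) -> int:
    return n if n < 2 else fib(n - 1) + fib(n - 2)


item = BenchmarkItem(task=fib, arguments={"n": 10})
result = BenchmarkResult(item, runtime=0.000123, peak_memory=4096, value=55)

print(item.get_method())            # fib(n=10)
print(result.get_header())          # task n runtime peak_memory value
print(result.get_in_data_format())  # <function fib at 0x...> 10 0.000123 4096 55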