Source code for mechanics_dsl.utils.profiling
"""
Profiling and performance monitoring for MechanicsDSL
"""
import platform
import signal
import threading
import time
from collections import defaultdict
from contextlib import contextmanager
from functools import wraps
from typing import Any, Callable, Dict, List, Optional
import numpy as np
from .config import config
from .logging import logger
try:
import psutil
except ImportError:
psutil = None # Optional dependency
[docs]
class PerformanceMonitor:
    """Collect named wall-clock timings and process-memory snapshots.

    Attributes:
        metrics: Maps a timer name to the list of durations (seconds)
            recorded for it by ``stop_timer``.
        memory_snapshots: Dicts produced by ``snapshot_memory``; each holds
            the ``rss``/``vms``/``percent`` floats plus a ``label`` string
            and a ``timestamp`` from ``time.time()``.
        start_times: Maps the name of every running timer to the
            ``time.perf_counter()`` reading taken when it was started.
    """

    def __init__(self) -> None:
        self.metrics: Dict[str, List[float]] = defaultdict(list)
        # Values are mixed (floats plus the string label), hence ``Any``.
        self.memory_snapshots: List[Dict[str, Any]] = []
        self.start_times: Dict[str, float] = {}

    def start_timer(self, name: str) -> None:
        """Start (or restart) the timer called ``name``.

        An invalid name (not a string, or empty) is replaced by
        ``"unnamed"`` after a warning. Note that ``stop_timer`` rejects
        invalid names instead of mapping them, so such a timer has to be
        stopped explicitly with ``stop_timer("unnamed")``.

        Args:
            name: Identifier of the operation being timed.
        """
        if not isinstance(name, str) or not name:
            logger.warning(
                f"PerformanceMonitor.start_timer: invalid name '{name}', using 'unnamed'"
            )
            name = "unnamed"
        if name in self.start_times:
            logger.warning(
                f"PerformanceMonitor.start_timer: timer '{name}' already running, overwriting"
            )
        self.start_times[name] = time.perf_counter()

    def stop_timer(self, name: str) -> float:
        """Stop the timer called ``name`` and record its duration.

        Args:
            name: Identifier previously passed to ``start_timer``.

        Returns:
            The elapsed time in seconds, or ``0.0`` when the name is
            invalid, the timer was never started, or the duration could
            not be computed.
        """
        if not isinstance(name, str) or not name:
            logger.warning(f"PerformanceMonitor.stop_timer: invalid name '{name}'")
            return 0.0
        if name not in self.start_times:
            logger.warning(f"PerformanceMonitor.stop_timer: timer '{name}' was not started")
            return 0.0
        try:
            duration = time.perf_counter() - self.start_times[name]
            if duration < 0:
                logger.warning(
                    f"PerformanceMonitor.stop_timer: negative duration for '{name}', clock issue?"
                )
                duration = 0.0
            if duration > 86400:  # More than 24 hours seems wrong
                logger.warning(
                    f"PerformanceMonitor.stop_timer: suspiciously long duration {duration}s for '{name}'"  # noqa: E501
                )
            self.metrics[name].append(duration)
            del self.start_times[name]
            return duration
        except (KeyError, TypeError, ValueError) as e:
            logger.error(f"PerformanceMonitor.stop_timer: error stopping timer '{name}': {e}")
            return 0.0

    def get_memory_usage(self) -> Dict[str, float]:
        """Return the current process memory usage.

        Returns:
            A dict with ``rss`` and ``vms`` in MB and ``percent`` as
            reported by ``psutil``. All three are ``0.0`` when ``psutil``
            is not installed or the query fails.
        """
        if psutil is None:
            return {"rss": 0.0, "vms": 0.0, "percent": 0.0}
        try:
            process = psutil.Process()
            mem_info = process.memory_info()
            return {
                "rss": mem_info.rss / 1024 / 1024,  # Resident Set Size
                "vms": mem_info.vms / 1024 / 1024,  # Virtual Memory Size
                "percent": process.memory_percent(),
            }
        except Exception as e:
            # Deliberately best-effort: monitoring must never break the
            # caller, but leave a trace instead of failing silently.
            logger.debug(f"PerformanceMonitor.get_memory_usage: could not read memory usage: {e}")
            return {"rss": 0.0, "vms": 0.0, "percent": 0.0}

    def snapshot_memory(self, label: str = "") -> None:
        """Append a labelled, timestamped memory reading to ``memory_snapshots``.

        Args:
            label: Free-form tag stored with the snapshot.
        """
        snapshot: Dict[str, Any] = dict(self.get_memory_usage())
        snapshot["label"] = label
        snapshot["timestamp"] = time.time()
        self.memory_snapshots.append(snapshot)

    def get_stats(self, name: str) -> Dict[str, float]:
        """Summarise the durations recorded under ``name``.

        Non-numeric and non-finite samples (NaN, infinity) are ignored.

        Args:
            name: Metric name to summarise.

        Returns:
            A dict with ``count``, ``total``, ``mean``, ``std``, ``min``
            and ``max``, or an empty dict when the name is invalid, there
            are no usable samples, or the computation fails.
        """
        if not isinstance(name, str) or not name:
            logger.warning(f"PerformanceMonitor.get_stats: invalid name '{name}'")
            return {}
        # ``.get`` avoids creating an empty entry in the defaultdict.
        values = self.metrics.get(name)
        if not values:
            return {}
        try:
            # Filter out invalid values
            valid_values = [v for v in values if isinstance(v, (int, float)) and np.isfinite(v)]
            if not valid_values:
                logger.warning(f"PerformanceMonitor.get_stats: no valid values for '{name}'")
                return {}
            return {
                "count": len(valid_values),
                # float() keeps the result a float even for all-int samples.
                "total": float(sum(valid_values)),
                "mean": float(np.mean(valid_values)),
                "std": float(np.std(valid_values)),
                "min": float(np.min(valid_values)),
                "max": float(np.max(valid_values)),
            }
        except Exception as e:
            logger.error(f"PerformanceMonitor.get_stats: error computing stats for '{name}': {e}")
            return {}

    def reset(self) -> None:
        """Discard all recorded metrics, memory snapshots and running timers."""
        self.metrics.clear()
        self.memory_snapshots.clear()
        self.start_times.clear()
# Global performance monitor: one module-level PerformanceMonitor instance,
# created when this module is imported.
_perf_monitor = PerformanceMonitor()
[docs]
@contextmanager
def timeout(seconds: float):
    """
    Cross-platform timeout context manager for timing out operations.

    On Unix the interval timer ``signal.ITIMER_REAL`` delivers ``SIGALRM``
    and the handler raises ``TimeoutError`` inside the running block, so the
    operation is actually interrupted. Fractional values are honoured.

    On Windows there is no ``SIGALRM``; a ``threading.Timer`` only records
    that the deadline passed. The block cannot be interrupted, so
    ``TimeoutError`` is raised once the block has finished, and only if the
    block itself did not raise.

    Args:
        seconds: Maximum time allowed (must be positive)

    Raises:
        TimeoutError: If operation exceeds time limit
        TypeError: If seconds is not numeric
        ValueError: If seconds is not positive
    """
    if not isinstance(seconds, (int, float)):
        raise TypeError(f"seconds must be numeric, got {type(seconds).__name__}")
    if seconds <= 0:
        raise ValueError(f"seconds must be positive, got {seconds}")

    message = f"Operation timed out after {seconds} seconds"

    if platform.system() == "Windows":
        # Windows: Use threading.Timer (cannot interrupt CPU-bound operations).
        # An exception raised inside the timer thread would never reach the
        # caller, so the timer only sets a flag that is checked afterwards.
        timeout_occurred = threading.Event()
        timer = threading.Timer(seconds, timeout_occurred.set)
        timer.daemon = True
        timer.start()
        try:
            yield
        finally:
            timer.cancel()
            timer.join(timeout=0.1)
        # Reached only when the block finished without raising, so an
        # exception from the block is never masked.
        if timeout_occurred.is_set():
            raise TimeoutError(message)
    else:
        # Unix: Use SIGALRM (can interrupt operations)
        def timeout_handler(signum: int, frame: Any) -> None:
            raise TimeoutError(message)

        old_handler = signal.signal(signal.SIGALRM, timeout_handler)
        try:
            # setitimer accepts floats; alarm(int(seconds)) would truncate
            # e.g. 0.5 to 0 and thereby disable the timeout altogether.
            signal.setitimer(signal.ITIMER_REAL, seconds)
            yield
        finally:
            signal.setitimer(signal.ITIMER_REAL, 0)
            # signal.signal() returns None for a handler that was not
            # installed from Python; None cannot be passed back in.
            signal.signal(
                signal.SIGALRM,
                old_handler if old_handler is not None else signal.SIG_DFL,
            )
[docs]
def profile_function(func: Callable) -> Callable:
    """Wrap ``func`` so that each call is profiled with ``cProfile``.

    Profiling happens only while ``config.enable_profiling`` is truthy;
    otherwise the call is forwarded untouched. After a successful profiled
    call the 20 entries with the highest cumulative time are written to the
    debug log. If the profiler cannot be enabled (``ValueError``), the call
    runs unprofiled.

    Args:
        func: The callable to wrap.

    Returns:
        The wrapping function, carrying ``func``'s metadata.
    """

    @wraps(func)
    def wrapper(*args, **kwargs):
        # Fast path: profiling switched off.
        if not config.enable_profiling:
            return func(*args, **kwargs)

        import cProfile
        import pstats
        from io import StringIO

        prof = cProfile.Profile()
        try:
            prof.enable()
        except ValueError:
            # Another profiler is already active, skip profiling for this call
            logger.debug(f"Skipping profiling for {func.__name__}: another profiler is active")
            return func(*args, **kwargs)

        try:
            outcome = func(*args, **kwargs)
            report = StringIO()
            summary = pstats.Stats(prof, stream=report)
            summary.sort_stats("cumulative")
            summary.print_stats(20)  # Top 20 functions
            logger.debug(f"\n{'='*70}\nProfile for {func.__name__}:\n{report.getvalue()}\n{'='*70}")
            return outcome
        finally:
            # Always switch the profiler off, even when func raised.
            prof.disable()

    return wrapper