Source code for mechanics_dsl.solver_numba

"""
Numba-accelerated numerical solver for MechanicsDSL

Provides JIT-compiled solvers for significant performance improvements
on numerical ODE integration.

Usage:
    simulator = NumbaSimulator(symbolic_engine)
    simulator.compile_equations(accelerations, coordinates)
    solution = simulator.simulate_numba(t_span, num_points=1000)
"""

from typing import Callable, Dict, List, Optional, Tuple

import numpy as np
import sympy as sp

from .utils import logger

# Try to import numba, fallback to pure Python if not available
try:
    import warnings

    from numba import njit
    from numba.core.errors import NumbaWarning

    warnings.filterwarnings("ignore", category=NumbaWarning)
    HAS_NUMBA = True
except ImportError:
    HAS_NUMBA = False
    logger.warning("Numba not available. Install with: pip install numba")

    # Define no-op decorators for fallback
    def njit(*args, **kwargs):
        if len(args) == 1 and callable(args[0]):
            return args[0]
        return lambda f: f

    def jit(*args, **kwargs):
        if len(args) == 1 and callable(args[0]):
            return args[0]
        return lambda f: f

    def prange(*args, **kwargs):
        return range(*args)


# ============================================================================
# JIT-Compiled Integration Kernels
# ============================================================================


@njit(cache=True)
def _rk4_step(
    y: np.ndarray, dt: float, dydt_func: Callable, t: float, params: np.ndarray
) -> np.ndarray:
    """
    Single RK4 integration step.

    Args:
        y: Current state vector
        dt: Time step
        dydt_func: Derivative function f(t, y, params) -> dy/dt
        t: Current time
        params: Parameter array

    Returns:
        Updated state vector
    """
    k1 = dydt_func(t, y, params)
    k2 = dydt_func(t + dt / 2, y + dt * k1 / 2, params)
    k3 = dydt_func(t + dt / 2, y + dt * k2 / 2, params)
    k4 = dydt_func(t + dt, y + dt * k3, params)

    return y + dt * (k1 + 2 * k2 + 2 * k3 + k4) / 6


@njit(cache=True)
def _euler_step(
    y: np.ndarray, dt: float, dydt_func: Callable, t: float, params: np.ndarray
) -> np.ndarray:
    """Simple Euler integration step."""
    return y + dt * dydt_func(t, y, params)


@njit(cache=True)
def _rk45_adaptive_step(
    y: np.ndarray,
    dt: float,
    dydt_func: Callable,
    t: float,
    params: np.ndarray,
    rtol: float,
    atol: float,
) -> Tuple[np.ndarray, float, bool]:
    """
    Adaptive RK45 (Dormand-Prince) step with error estimation.

    Returns:
        (new_y, new_dt, accepted) - updated state, suggested next dt, and whether step was accepted
    """
    # Dormand-Prince coefficients
    c2, c3, c4, c5, c6 = 1 / 5, 3 / 10, 4 / 5, 8 / 9, 1.0

    a21 = 1 / 5
    a31, a32 = 3 / 40, 9 / 40
    a41, a42, a43 = 44 / 45, -56 / 15, 32 / 9
    a51, a52, a53, a54 = 19372 / 6561, -25360 / 2187, 64448 / 6561, -212 / 729
    a61, a62, a63, a64, a65 = 9017 / 3168, -355 / 33, 46732 / 5247, 49 / 176, -5103 / 18656

    b1, b3, b4, b5, b6 = 35 / 384, 500 / 1113, 125 / 192, -2187 / 6784, 11 / 84
    e1, e3, e4, e5, e6, e7 = 71 / 57600, -71 / 16695, 71 / 1920, -17253 / 339200, 22 / 525, -1 / 40

    # Compute RK stages
    k1 = dydt_func(t, y, params)
    k2 = dydt_func(t + c2 * dt, y + dt * a21 * k1, params)
    k3 = dydt_func(t + c3 * dt, y + dt * (a31 * k1 + a32 * k2), params)
    k4 = dydt_func(t + c4 * dt, y + dt * (a41 * k1 + a42 * k2 + a43 * k3), params)
    k5 = dydt_func(t + c5 * dt, y + dt * (a51 * k1 + a52 * k2 + a53 * k3 + a54 * k4), params)
    k6 = dydt_func(
        t + c6 * dt, y + dt * (a61 * k1 + a62 * k2 + a63 * k3 + a64 * k4 + a65 * k5), params
    )

    # 5th order solution
    y_new = y + dt * (b1 * k1 + b3 * k3 + b4 * k4 + b5 * k5 + b6 * k6)

    # Error estimate
    k7 = dydt_func(t + dt, y_new, params)
    error = dt * (e1 * k1 + e3 * k3 + e4 * k4 + e5 * k5 + e6 * k6 + e7 * k7)

    # Compute error norm
    scale = atol + rtol * np.maximum(np.abs(y), np.abs(y_new))
    err_norm = np.sqrt(np.mean((error / scale) ** 2))

    # Determine if step is accepted
    accepted = err_norm <= 1.0

    # Compute new step size
    safety = 0.9
    min_factor = 0.2
    max_factor = 5.0

    if err_norm == 0:
        factor = max_factor
    else:
        factor = safety * err_norm ** (-0.2)
        factor = max(min_factor, min(max_factor, factor))

    new_dt = dt * factor

    if accepted:
        return y_new, new_dt, True
    else:
        return y, new_dt, False


# ============================================================================
# JIT-Compiled ODE Functions
# ============================================================================


def create_numba_ode_function(
    accelerations: Dict[str, sp.Expr], coordinates: List[str], parameter_names: List[str]
) -> Callable:
    """
    Create a JIT-compiled ODE function from symbolic expressions.

    This generates a Numba-compatible function that can be used with
    the JIT-compiled integrators.

    Args:
        accelerations: Dictionary of {coord_ddot: expression}
        coordinates: List of coordinate names
        parameter_names: List of parameter names

    Returns:
        JIT-compiled ODE function f(t, y, params) -> dydt
    """
    from sympy.utilities.lambdify import lambdify

    # Build symbol list for lambdify
    t_sym = sp.Symbol("t")
    coord_symbols = []
    for coord in coordinates:
        coord_symbols.append(sp.Symbol(coord, real=True))
        coord_symbols.append(sp.Symbol(f"{coord}_dot", real=True))

    param_symbols = [sp.Symbol(p) for p in parameter_names]

    all_symbols = [t_sym] + coord_symbols + param_symbols

    # Create lambdified functions for each acceleration
    accel_funcs = []
    for coord in coordinates:
        accel_key = f"{coord}_ddot"
        if accel_key in accelerations:
            expr = accelerations[accel_key]
            func = lambdify(all_symbols, expr, modules=["numpy"])
            accel_funcs.append(func)
        else:
            # Zero acceleration
            accel_funcs.append(lambda *args: 0.0)

    n_coords = len(coordinates)
    len(parameter_names)

    # This wrapper cannot be JIT-compiled directly due to closure limitations
    # but we can use it with Numba's objmode for hybrid compilation
    def ode_func(t: float, y: np.ndarray, params: np.ndarray) -> np.ndarray:
        """ODE function: dy/dt = f(t, y, params)"""
        dydt = np.zeros(len(y))

        # Build argument list: t, coord1, coord1_dot, coord2, ...
        args = [t] + list(y) + list(params)

        for i in range(n_coords):
            # d(coord)/dt = coord_dot
            dydt[2 * i] = y[2 * i + 1]
            # d(coord_dot)/dt = acceleration
            dydt[2 * i + 1] = accel_funcs[i](*args)

        return dydt

    return ode_func


# ============================================================================
# NumbaSimulator Class
# ============================================================================


[docs] class NumbaSimulator: """ Numba-accelerated numerical simulator for MechanicsDSL. Provides significant speedups over SciPy's solve_ivp for simple to moderately complex systems. Example: >>> from mechanics_dsl.solver_numba import NumbaSimulator >>> sim = NumbaSimulator(symbolic_engine) >>> sim.compile_equations(accelerations, coordinates) >>> solution = sim.simulate_numba(t_span=(0, 10), num_points=1000) """ def __init__(self, symbolic_engine=None): """ Initialize the Numba simulator. Args: symbolic_engine: Optional SymbolicEngine instance """ self.symbolic = symbolic_engine self.parameters: Dict[str, float] = {} self.initial_conditions: Dict[str, float] = {} self.coordinates: List[str] = [] self._ode_func: Optional[Callable] = None self._param_array: Optional[np.ndarray] = None self._is_compiled: bool = False if not HAS_NUMBA: logger.warning("Numba not available. Falling back to pure NumPy.")
[docs] def set_parameters(self, params: Dict[str, float]) -> None: """Set physical parameters.""" self.parameters.update(params) self._update_param_array()
[docs] def set_initial_conditions(self, conditions: Dict[str, float]) -> None: """Set initial conditions.""" self.initial_conditions.update(conditions)
def _update_param_array(self) -> None: """Update the parameter array for JIT functions.""" if self.parameters: self._param_array = np.array(list(self.parameters.values()), dtype=np.float64)
[docs] def compile_equations(self, accelerations: Dict[str, sp.Expr], coordinates: List[str]) -> None: """ Compile symbolic equations to Numba-compatible functions. Args: accelerations: Dictionary of {coord_ddot: symbolic_expression} coordinates: List of coordinate names """ self.coordinates = coordinates param_names = list(self.parameters.keys()) self._ode_func = create_numba_ode_function(accelerations, coordinates, param_names) self._is_compiled = True logger.info(f"Compiled {len(coordinates)} coordinates for Numba solver")
def _get_initial_state(self) -> np.ndarray: """Get initial state vector from initial conditions.""" state = [] for coord in self.coordinates: state.append(self.initial_conditions.get(coord, 0.0)) state.append(self.initial_conditions.get(f"{coord}_dot", 0.0)) return np.array(state, dtype=np.float64)
[docs] def simulate_numba( self, t_span: Tuple[float, float], num_points: int = 1000, method: str = "rk4", rtol: float = 1e-6, atol: float = 1e-9, ) -> Dict: """ Run simulation using Numba-accelerated integrator. Args: t_span: (t_start, t_end) time interval num_points: Number of output points method: Integration method ('euler', 'rk4', 'rk45') rtol: Relative tolerance (for adaptive methods) atol: Absolute tolerance (for adaptive methods) Returns: Dictionary with 't' and 'y' arrays """ if not self._is_compiled: raise RuntimeError("Equations not compiled. Call compile_equations first.") t_start, t_end = t_span y0 = self._get_initial_state() if self._param_array is None: self._update_param_array() params = self._param_array if self._param_array is not None else np.array([]) if method == "rk45": return self._integrate_adaptive(t_start, t_end, y0, params, num_points, rtol, atol) else: return self._integrate_fixed(t_start, t_end, y0, params, num_points, method)
def _integrate_fixed( self, t_start: float, t_end: float, y0: np.ndarray, params: np.ndarray, num_points: int, method: str, ) -> Dict: """Fixed-step integration (Euler or RK4).""" t_eval = np.linspace(t_start, t_end, num_points) dt = (t_end - t_start) / (num_points - 1) y = np.zeros((len(y0), num_points)) y[:, 0] = y0 current_y = y0.copy() for i in range(1, num_points): t = t_eval[i - 1] if method == "euler": dydt = self._ode_func(t, current_y, params) current_y = current_y + dt * dydt else: # rk4 k1 = self._ode_func(t, current_y, params) k2 = self._ode_func(t + dt / 2, current_y + dt * k1 / 2, params) k3 = self._ode_func(t + dt / 2, current_y + dt * k2 / 2, params) k4 = self._ode_func(t + dt, current_y + dt * k3, params) current_y = current_y + dt * (k1 + 2 * k2 + 2 * k3 + k4) / 6 y[:, i] = current_y return {"t": t_eval, "y": y, "success": True} def _integrate_adaptive( self, t_start: float, t_end: float, y0: np.ndarray, params: np.ndarray, num_points: int, rtol: float, atol: float, ) -> Dict: """Adaptive RK45 integration with dense output.""" # Use adaptive stepping internally, then interpolate to output points t_list = [t_start] y_list = [y0.copy()] t = t_start y = y0.copy() dt = (t_end - t_start) / 100 # Initial step size guess dt_min = (t_end - t_start) / 1e6 dt_max = (t_end - t_start) / 10 max_steps = num_points * 100 steps = 0 while t < t_end and steps < max_steps: # Ensure we don't step past t_end if t + dt > t_end: dt = t_end - t # RK45 step k1 = self._ode_func(t, y, params) k2 = self._ode_func(t + dt / 5, y + dt * k1 / 5, params) k3 = self._ode_func(t + 3 * dt / 10, y + dt * (3 * k1 / 40 + 9 * k2 / 40), params) k4 = self._ode_func( t + 4 * dt / 5, y + dt * (44 * k1 / 45 - 56 * k2 / 15 + 32 * k3 / 9), params ) k5 = self._ode_func( t + 8 * dt / 9, y + dt * (19372 * k1 / 6561 - 25360 * k2 / 2187 + 64448 * k3 / 6561 - 212 * k4 / 729), params, ) k6 = self._ode_func( t + dt, y + dt * ( 9017 * k1 / 3168 - 355 * k2 / 33 + 46732 * k3 / 5247 + 49 * k4 / 176 - 5103 * k5 / 18656 ), params, ) # 5th order solution y_new = y + dt * ( 35 * k1 / 384 + 500 * k3 / 1113 + 125 * k4 / 192 - 2187 * k5 / 6784 + 11 * k6 / 84 ) # Error estimate (difference between 4th and 5th order) k7 = self._ode_func(t + dt, y_new, params) error = dt * ( 71 * k1 / 57600 - 71 * k3 / 16695 + 71 * k4 / 1920 - 17253 * k5 / 339200 + 22 * k6 / 525 - k7 / 40 ) # Error norm scale = atol + rtol * np.maximum(np.abs(y), np.abs(y_new)) err_norm = np.sqrt(np.mean((error / scale) ** 2)) if err_norm <= 1.0: # Accept step t = t + dt y = y_new t_list.append(t) y_list.append(y.copy()) # Update step size if err_norm == 0: factor = 5.0 else: factor = 0.9 * err_norm ** (-0.2) factor = max(0.2, min(5.0, factor)) dt = np.clip(dt * factor, dt_min, dt_max) steps += 1 if steps >= max_steps: logger.warning(f"Max steps ({max_steps}) reached in adaptive integration") # Interpolate to requested output points t_internal = np.array(t_list) y_internal = np.array(y_list).T # Shape: (n_states, n_internal_points) t_eval = np.linspace(t_start, t_end, num_points) y_output = np.zeros((y0.shape[0], num_points)) for i in range(y0.shape[0]): y_output[i, :] = np.interp(t_eval, t_internal, y_internal[i, :]) return {"t": t_eval, "y": y_output, "success": True}
# ============================================================================ # Convenience Functions # ============================================================================
[docs] def is_numba_available() -> bool: """Check if Numba is available for JIT compilation.""" return HAS_NUMBA
def get_numba_version() -> Optional[str]: """Get Numba version if available.""" if HAS_NUMBA: import numba return numba.__version__ return None