"""
Numba-accelerated numerical solver for MechanicsDSL.

Provides JIT-compiled solvers for significant performance improvements
on numerical ODE integration.

Usage:
    simulator = NumbaSimulator(symbolic_engine)
    simulator.compile_equations(accelerations, coordinates)
    solution = simulator.simulate_numba(t_span, num_points=1000)
"""
from typing import Callable, Dict, List, Optional, Tuple
import numpy as np
import sympy as sp
from .utils import logger
# Try to import numba, fallback to pure Python if not available
try:
    import warnings

    # Import the same three names the fallback branch defines, so that
    # `njit`, `jit` and `prange` exist in this module whether or not Numba
    # is installed.
    from numba import jit, njit, prange
    from numba.core.errors import NumbaWarning

    # Silence Numba's own warning category for the whole process.
    warnings.filterwarnings("ignore", category=NumbaWarning)
    HAS_NUMBA = True
except ImportError:
    HAS_NUMBA = False
    logger.warning("Numba not available. Install with: pip install numba")

    # Define no-op stand-ins so decorated code still imports and runs as
    # plain Python.
    def njit(*args, **kwargs):
        """No-op replacement for ``numba.njit``.

        Works both as a bare decorator (``@njit``) and as a decorator
        factory (``@njit(cache=True)``); in either case the decorated
        function is returned unchanged.
        """
        if len(args) == 1 and callable(args[0]):
            return args[0]
        return lambda f: f

    # ``jit`` has exactly the same pass-through behaviour as ``njit``.
    jit = njit

    def prange(*args, **kwargs):
        """No-op replacement for ``numba.prange``: a plain ``range``."""
        return range(*args)
# ============================================================================
# JIT-Compiled Integration Kernels
# ============================================================================
@njit(cache=True)
def _rk4_step(
    y: np.ndarray, dt: float, dydt_func: Callable, t: float, params: np.ndarray
) -> np.ndarray:
    """
    Advance the state by one classical fourth-order Runge-Kutta step.

    Args:
        y: State vector at time ``t``
        dt: Step size
        dydt_func: Right-hand side, called as ``dydt_func(t, y, params)``
        t: Time at the start of the step
        params: Parameter array forwarded unchanged to ``dydt_func``

    Returns:
        State vector at time ``t + dt``
    """
    # NOTE(review): when Numba is active, dydt_func presumably has to be a
    # Numba-compiled function for this kernel to type-check - confirm.
    t_mid = t + dt / 2
    t_next = t + dt

    slope_start = dydt_func(t, y, params)
    slope_mid_a = dydt_func(t_mid, y + dt * slope_start / 2, params)
    slope_mid_b = dydt_func(t_mid, y + dt * slope_mid_a / 2, params)
    slope_end = dydt_func(t_next, y + dt * slope_mid_b, params)

    # Weights 1-2-2-1, summed in the same order as the textbook formula.
    weighted = slope_start + 2 * slope_mid_a + 2 * slope_mid_b + slope_end
    return y + dt * weighted / 6
@njit(cache=True)
def _euler_step(
    y: np.ndarray, dt: float, dydt_func: Callable, t: float, params: np.ndarray
) -> np.ndarray:
    """
    Advance the state by one explicit (forward) Euler step.

    Args:
        y: State vector at time ``t``
        dt: Step size
        dydt_func: Right-hand side, called as ``dydt_func(t, y, params)``
        t: Time at the start of the step
        params: Parameter array forwarded unchanged to ``dydt_func``

    Returns:
        State vector at time ``t + dt``
    """
    slope = dydt_func(t, y, params)
    return y + dt * slope
@njit(cache=True)
def _rk45_adaptive_step(
    y: np.ndarray,
    dt: float,
    dydt_func: Callable,
    t: float,
    params: np.ndarray,
    rtol: float,
    atol: float,
) -> Tuple[np.ndarray, float, bool]:
    """
    Adaptive RK45 (Dormand-Prince) step with error estimation.

    Args:
        y: State vector at time ``t``
        dt: Trial step size
        dydt_func: Right-hand side, called as ``dydt_func(t, y, params)``
        t: Time at the start of the step
        params: Parameter array forwarded unchanged to ``dydt_func``
        rtol: Relative tolerance used to scale the error estimate
        atol: Absolute tolerance used to scale the error estimate

    Returns:
        (new_y, new_dt, accepted) - the state after the step (the unchanged
        input state when the step is rejected), the suggested next step
        size, and whether the step was accepted. The caller advances time
        by ``dt`` only when ``accepted`` is True.
    """
    # Dormand-Prince coefficients
    c2, c3, c4, c5, c6 = 1 / 5, 3 / 10, 4 / 5, 8 / 9, 1.0
    a21 = 1 / 5
    a31, a32 = 3 / 40, 9 / 40
    a41, a42, a43 = 44 / 45, -56 / 15, 32 / 9
    a51, a52, a53, a54 = 19372 / 6561, -25360 / 2187, 64448 / 6561, -212 / 729
    a61, a62, a63, a64, a65 = 9017 / 3168, -355 / 33, 46732 / 5247, 49 / 176, -5103 / 18656
    b1, b3, b4, b5, b6 = 35 / 384, 500 / 1113, 125 / 192, -2187 / 6784, 11 / 84
    e1, e3, e4, e5, e6, e7 = 71 / 57600, -71 / 16695, 71 / 1920, -17253 / 339200, 22 / 525, -1 / 40

    # Compute RK stages
    k1 = dydt_func(t, y, params)
    k2 = dydt_func(t + c2 * dt, y + dt * a21 * k1, params)
    k3 = dydt_func(t + c3 * dt, y + dt * (a31 * k1 + a32 * k2), params)
    k4 = dydt_func(t + c4 * dt, y + dt * (a41 * k1 + a42 * k2 + a43 * k3), params)
    k5 = dydt_func(t + c5 * dt, y + dt * (a51 * k1 + a52 * k2 + a53 * k3 + a54 * k4), params)
    k6 = dydt_func(
        t + c6 * dt, y + dt * (a61 * k1 + a62 * k2 + a63 * k3 + a64 * k4 + a65 * k5), params
    )

    # 5th order solution
    y_new = y + dt * (b1 * k1 + b3 * k3 + b4 * k4 + b5 * k5 + b6 * k6)

    # Error estimate; the seventh stage is evaluated at the new state.
    k7 = dydt_func(t + dt, y_new, params)
    error = dt * (e1 * k1 + e3 * k3 + e4 * k4 + e5 * k5 + e6 * k6 + e7 * k7)

    # RMS error norm, scaled per component by the mixed tolerance.
    scale = atol + rtol * np.maximum(np.abs(y), np.abs(y_new))
    err_norm = np.sqrt(np.mean((error / scale) ** 2))

    # A NaN norm compares False here, so a non-finite step is never accepted.
    accepted = err_norm <= 1.0

    # Compute new step size
    safety = 0.9
    min_factor = 0.2
    max_factor = 5.0
    if not np.isfinite(err_norm):
        # The step blew up (NaN/inf). Shrink as hard as allowed; without
        # this branch a NaN would slip through min()/max() and come out as
        # max_factor, i.e. the rejected step would be retried 5x larger.
        factor = min_factor
    elif err_norm == 0:
        factor = max_factor
    else:
        factor = safety * err_norm ** (-0.2)
        factor = max(min_factor, min(max_factor, factor))
    new_dt = dt * factor

    if accepted:
        return y_new, new_dt, True
    else:
        return y, new_dt, False
# ============================================================================
# JIT-Compiled ODE Functions
# ============================================================================
def create_numba_ode_function(
    accelerations: Dict[str, sp.Expr], coordinates: List[str], parameter_names: List[str]
) -> Callable:
    """
    Build a first-order ODE right-hand side from symbolic accelerations.

    Each acceleration expression is turned into a NumPy-backed callable with
    ``sympy.lambdify``. The returned wrapper is an ordinary Python closure:
    it is NOT JIT-compiled, because it closes over the list of lambdified
    functions.

    The state vector is laid out as
    ``[q0, q0_dot, q1, q1_dot, ...]`` following the order of ``coordinates``,
    and ``params`` must follow the order of ``parameter_names``.

    Args:
        accelerations: Dictionary of {"<coord>_ddot": expression}. A
            coordinate without an entry gets zero acceleration.
        coordinates: List of coordinate names
        parameter_names: List of parameter names, in the order their values
            appear in the ``params`` array passed to the returned function

    Returns:
        Function ``f(t, y, params) -> dydt`` returning a new array with the
        same length as ``y``.
    """
    from sympy.utilities.lambdify import lambdify

    # Argument order of every lambdified function:
    # t, q0, q0_dot, q1, q1_dot, ..., then the parameters.
    t_sym = sp.Symbol("t")
    coord_symbols = []
    for coord in coordinates:
        coord_symbols.append(sp.Symbol(coord, real=True))
        coord_symbols.append(sp.Symbol(f"{coord}_dot", real=True))
    param_symbols = [sp.Symbol(p) for p in parameter_names]
    all_symbols = [t_sym] + coord_symbols + param_symbols

    # One callable per coordinate, in coordinate order.
    accel_funcs = []
    for coord in coordinates:
        accel_key = f"{coord}_ddot"
        if accel_key in accelerations:
            accel_funcs.append(lambdify(all_symbols, accelerations[accel_key], modules=["numpy"]))
        else:
            # Zero acceleration
            accel_funcs.append(lambda *args: 0.0)

    n_coords = len(coordinates)

    def ode_func(t: float, y: np.ndarray, params: np.ndarray) -> np.ndarray:
        """ODE function: dy/dt = f(t, y, params)"""
        dydt = np.zeros(len(y))
        # Build argument list: t, coord1, coord1_dot, coord2, ...
        args = [t] + list(y) + list(params)
        for i in range(n_coords):
            # d(coord)/dt = coord_dot
            dydt[2 * i] = y[2 * i + 1]
            # d(coord_dot)/dt = acceleration
            dydt[2 * i + 1] = accel_funcs[i](*args)
        return dydt

    return ode_func
# ============================================================================
# NumbaSimulator Class
# ============================================================================
class NumbaSimulator:
    """
    Numerical simulator for MechanicsDSL equations of motion.

    Integrates the first-order system produced by
    ``create_numba_ode_function`` with a fixed-step Euler or RK4 scheme, or
    with an adaptive Dormand-Prince RK45 scheme whose result is linearly
    interpolated onto an evenly spaced output grid.

    The state vector is ``[q0, q0_dot, q1, q1_dot, ...]`` in the order of
    the compiled coordinates.

    Example:
        >>> from mechanics_dsl.solver_numba import NumbaSimulator
        >>> sim = NumbaSimulator(symbolic_engine)
        >>> sim.compile_equations(accelerations, coordinates)
        >>> solution = sim.simulate_numba(t_span=(0, 10), num_points=1000)
    """

    def __init__(self, symbolic_engine=None):
        """
        Initialize the simulator.

        Args:
            symbolic_engine: Optional SymbolicEngine instance (only stored)
        """
        self.symbolic = symbolic_engine
        self.parameters: Dict[str, float] = {}
        self.initial_conditions: Dict[str, float] = {}
        self.coordinates: List[str] = []
        self._ode_func: Optional[Callable] = None
        self._param_array: Optional[np.ndarray] = None
        self._is_compiled: bool = False
        # Remembered from the last compile so the ODE function can be
        # rebuilt when parameters are added afterwards.
        self._accelerations: Dict[str, "sp.Expr"] = {}
        self._param_names: List[str] = []
        if not HAS_NUMBA:
            logger.warning("Numba not available. Falling back to pure NumPy.")

    def set_parameters(self, params: Dict[str, float]) -> None:
        """
        Set (or update) physical parameters.

        If equations are already compiled and this call introduces new
        parameter names, the ODE function is rebuilt so that its positional
        argument list keeps matching the parameter array.
        """
        self.parameters.update(params)
        self._update_param_array()
        if self._is_compiled and list(self.parameters) != self._param_names:
            self.compile_equations(self._accelerations, self.coordinates)

    def set_initial_conditions(self, conditions: Dict[str, float]) -> None:
        """Set (or update) initial conditions, keyed by ``coord`` / ``coord_dot``."""
        self.initial_conditions.update(conditions)

    def _update_param_array(self) -> None:
        """Rebuild the float64 parameter array from ``self.parameters`` (insertion order)."""
        if self.parameters:
            self._param_array = np.array(list(self.parameters.values()), dtype=np.float64)

    def compile_equations(
        self, accelerations: "Dict[str, sp.Expr]", coordinates: List[str]
    ) -> None:
        """
        Compile symbolic equations into a callable ODE right-hand side.

        Args:
            accelerations: Dictionary of {coord_ddot: symbolic_expression}
            coordinates: List of coordinate names
        """
        self.coordinates = coordinates
        self._accelerations = accelerations
        # Snapshot of the parameter order baked into the compiled function.
        self._param_names = list(self.parameters)
        self._ode_func = create_numba_ode_function(
            accelerations, coordinates, list(self._param_names)
        )
        self._is_compiled = True
        logger.info(f"Compiled {len(coordinates)} coordinates for Numba solver")

    def _get_initial_state(self) -> np.ndarray:
        """Build the initial state vector; missing entries default to 0.0."""
        state = []
        for coord in self.coordinates:
            state.append(self.initial_conditions.get(coord, 0.0))
            state.append(self.initial_conditions.get(f"{coord}_dot", 0.0))
        return np.array(state, dtype=np.float64)

    def simulate_numba(
        self,
        t_span: Tuple[float, float],
        num_points: int = 1000,
        method: str = "rk4",
        rtol: float = 1e-6,
        atol: float = 1e-9,
    ) -> Dict:
        """
        Run a simulation over ``t_span``.

        Args:
            t_span: (t_start, t_end) time interval
            num_points: Number of evenly spaced output points (>= 1)
            method: Integration method ('euler', 'rk4', 'rk45'),
                case-insensitive. An unrecognised name logs a warning and
                falls back to 'rk4'.
            rtol: Relative tolerance (for adaptive methods)
            atol: Absolute tolerance (for adaptive methods)

        Returns:
            Dictionary with 't' (shape ``(num_points,)``), 'y' (shape
            ``(n_states, num_points)``) and a boolean 'success'.

        Raises:
            RuntimeError: If ``compile_equations`` has not been called.
            ValueError: If ``num_points`` is smaller than 1.
        """
        if not self._is_compiled:
            raise RuntimeError("Equations not compiled. Call compile_equations first.")

        t_start, t_end = t_span
        y0 = self._get_initial_state()
        if self._param_array is None:
            self._update_param_array()
        params = self._param_array if self._param_array is not None else np.array([])

        scheme = method.lower() if isinstance(method, str) else method
        if scheme == "rk45":
            return self._integrate_adaptive(t_start, t_end, y0, params, num_points, rtol, atol)
        if scheme not in ("euler", "rk4"):
            # Keep the historical fallback, but no longer silently.
            logger.warning(f"Unknown integration method {method!r}; falling back to 'rk4'")
            scheme = "rk4"
        return self._integrate_fixed(t_start, t_end, y0, params, num_points, scheme)

    def _integrate_fixed(
        self,
        t_start: float,
        t_end: float,
        y0: np.ndarray,
        params: np.ndarray,
        num_points: int,
        method: str,
    ) -> Dict:
        """
        Fixed-step integration on the output grid itself.

        ``method == "euler"`` selects forward Euler; any other value selects
        classical RK4. The step equals the spacing of the output grid. With
        a single output point only the initial state is returned.
        'success' is False when the result contains NaN or inf.
        """
        if num_points < 1:
            raise ValueError("num_points must be at least 1")

        t_eval = np.linspace(t_start, t_end, num_points)
        y = np.zeros((len(y0), num_points))
        y[:, 0] = y0
        if num_points == 1:
            # Nothing to integrate; avoids dividing by (num_points - 1) == 0.
            return {"t": t_eval, "y": y, "success": True}

        dt = (t_end - t_start) / (num_points - 1)
        use_euler = method == "euler"
        ode = self._ode_func
        current_y = y0.copy()

        for i in range(1, num_points):
            t = t_eval[i - 1]
            if use_euler:
                current_y = current_y + dt * ode(t, current_y, params)
            else:  # rk4
                k1 = ode(t, current_y, params)
                k2 = ode(t + dt / 2, current_y + dt * k1 / 2, params)
                k3 = ode(t + dt / 2, current_y + dt * k2 / 2, params)
                k4 = ode(t + dt, current_y + dt * k3, params)
                current_y = current_y + dt * (k1 + 2 * k2 + 2 * k3 + k4) / 6
            y[:, i] = current_y

        return {"t": t_eval, "y": y, "success": bool(np.isfinite(y).all())}

    def _dopri5_step(
        self, t: float, y: np.ndarray, dt: float, params: np.ndarray
    ) -> Tuple[np.ndarray, np.ndarray]:
        """Take one Dormand-Prince trial step; return (5th-order state, error estimate)."""
        ode = self._ode_func
        k1 = ode(t, y, params)
        k2 = ode(t + dt / 5, y + dt * k1 / 5, params)
        k3 = ode(t + 3 * dt / 10, y + dt * (3 * k1 / 40 + 9 * k2 / 40), params)
        k4 = ode(t + 4 * dt / 5, y + dt * (44 * k1 / 45 - 56 * k2 / 15 + 32 * k3 / 9), params)
        k5 = ode(
            t + 8 * dt / 9,
            y + dt * (19372 * k1 / 6561 - 25360 * k2 / 2187 + 64448 * k3 / 6561 - 212 * k4 / 729),
            params,
        )
        k6 = ode(
            t + dt,
            y
            + dt
            * (
                9017 * k1 / 3168
                - 355 * k2 / 33
                + 46732 * k3 / 5247
                + 49 * k4 / 176
                - 5103 * k5 / 18656
            ),
            params,
        )
        # 5th order solution
        y_new = y + dt * (
            35 * k1 / 384 + 500 * k3 / 1113 + 125 * k4 / 192 - 2187 * k5 / 6784 + 11 * k6 / 84
        )
        # Error estimate (difference between 4th and 5th order)
        k7 = ode(t + dt, y_new, params)
        error = dt * (
            71 * k1 / 57600
            - 71 * k3 / 16695
            + 71 * k4 / 1920
            - 17253 * k5 / 339200
            + 22 * k6 / 525
            - k7 / 40
        )
        return y_new, error

    def _integrate_adaptive(
        self,
        t_start: float,
        t_end: float,
        y0: np.ndarray,
        params: np.ndarray,
        num_points: int,
        rtol: float,
        atol: float,
    ) -> Dict:
        """
        Adaptive RK45 integration, linearly interpolated to the output grid.

        Steps are chosen internally (between 1e-6 and 1/10 of the interval
        length, at most ``100 * num_points`` attempts) and the accepted
        points are interpolated with ``np.interp`` onto ``num_points``
        evenly spaced times. Works for ``t_end < t_start`` as well.

        'success' is False when ``t_end`` was not reached (step budget
        exhausted, or the tolerance cannot be met at the minimum step) or
        when the output contains NaN/inf. In that case output times beyond
        the last accepted point hold the last accepted state.
        """
        if num_points < 1:
            raise ValueError("num_points must be at least 1")

        span = t_end - t_start
        direction = 1.0 if span >= 0 else -1.0
        length = abs(span)
        h = length / 100  # Initial step size guess (magnitude)
        h_min = length / 1e6
        h_max = length / 10
        max_steps = num_points * 100

        t = t_start
        y = y0.copy()
        t_list = [t_start]
        y_list = [y0.copy()]
        steps = 0
        reached_end = length == 0
        stalled = False

        while not reached_end and steps < max_steps:
            remaining = abs(t_end - t)
            # Ensure we don't step past t_end
            last_step = h >= remaining
            h_used = remaining if last_step else h
            dt = direction * h_used

            y_new, error = self._dopri5_step(t, y, dt, params)

            # Error norm
            scale = atol + rtol * np.maximum(np.abs(y), np.abs(y_new))
            err_norm = np.sqrt(np.mean((error / scale) ** 2))

            accepted = bool(err_norm <= 1.0)  # False for NaN
            if accepted:
                # Land exactly on t_end instead of t + (t_end - t), which
                # can differ from t_end by a rounding error.
                t = t_end if last_step else t + dt
                y = y_new
                t_list.append(t)
                y_list.append(y.copy())
                reached_end = last_step

            # Update step size
            if not np.isfinite(err_norm):
                # Blow-up: shrink. A NaN would otherwise pass through
                # min()/max() as the growth limit 5.0.
                factor = 0.2
            elif err_norm == 0:
                factor = 5.0
            else:
                factor = max(0.2, min(5.0, 0.9 * err_norm ** (-0.2)))
            h_next = float(np.clip(h_used * factor, h_min, h_max))
            steps += 1

            if not accepted and min(h_next, remaining) >= h_used:
                # The retry would use the very same step and fail the same
                # way; stop instead of burning the whole step budget.
                stalled = True
                break
            h = h_next

        if stalled:
            logger.warning(
                f"Adaptive integration stalled at t={t}: tolerance not met at minimum step size"
            )
        elif not reached_end:
            logger.warning(f"Max steps ({max_steps}) reached in adaptive integration")

        # Interpolate to requested output points
        t_internal = np.array(t_list)
        y_internal = np.array(y_list).T  # Shape: (n_states, n_internal_points)
        if direction < 0:
            # np.interp needs increasing sample positions.
            t_internal = t_internal[::-1]
            y_internal = y_internal[:, ::-1]
        t_eval = np.linspace(t_start, t_end, num_points)
        y_output = np.zeros((y0.shape[0], num_points))
        for i in range(y0.shape[0]):
            y_output[i, :] = np.interp(t_eval, t_internal, y_internal[i, :])

        success = bool(reached_end and np.isfinite(y_output).all())
        return {"t": t_eval, "y": y_output, "success": success}
# ============================================================================
# Convenience Functions
# ============================================================================
def is_numba_available() -> bool:
    """Return True when Numba was imported successfully at module load."""
    return HAS_NUMBA
def get_numba_version() -> Optional[str]:
    """Return the installed Numba version string, or None when Numba is unavailable."""
    if not HAS_NUMBA:
        return None

    # Imported lazily so this module still loads without Numba.
    import numba

    return numba.__version__