JAX-first unscented Kalman filtering with differentiable log-likelihoods for gradient-based parameter estimation.
- UKF core: Merwe-scaled sigma points, JIT-compiled predict/update steps, and optional batchable measurement models.
- Smoothing + EM hooks: Rauch–Tung–Striebel smoother plus EM-friendly statistics (`Q_st`, `P_k_k1`) for refining process noise and priors.
- Dynamic systems: `DynamicSystem` base class with RK4 integration, Jacobians, and linearised transition matrices for mixed continuous/discrete models.
- Observation management: time-windowed observation interfaces, parameter slicing, and aligned multi-sensor batching via `ObservationManager`.
- Process noise shaping: continuous-to-discrete noise integration (`process_white_noise`) and parameterised covariance builders (`ParamUnpacker.Qc`).
- Results + exports: structured forward/smoother outputs with log-likelihoods, CSV/pickle exporters, and convenience fusion (`bayesian_inverse_variance_weighting`).
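For readers who want to see what the UKF core computes, here is a minimal plain-JAX sketch of Merwe-scaled sigma points and the unscented transform. The names `merwe_sigma_points` and `unscented_transform` and the default `alpha`, `beta`, `kappa` values are illustrative assumptions, not gradse's API.

```python
import jax
import jax.numpy as jnp

# Small alpha produces large-magnitude weights; float64 avoids cancellation error.
jax.config.update("jax_enable_x64", True)


def merwe_sigma_points(x, P, alpha=1e-3, beta=2.0, kappa=0.0):
    """Return (sigmas, Wm, Wc) for the Merwe scaled sigma-point set (2n + 1 points)."""
    n = x.shape[0]
    lam = alpha**2 * (n + kappa) - n
    # Lower-triangular square root of the scaled covariance.
    L = jnp.linalg.cholesky((n + lam) * P)
    # Rows: the mean, then mean + each column of L, then mean - each column of L.
    sigmas = jnp.concatenate([x[None, :], x + L.T, x - L.T], axis=0)
    w = jnp.full(2 * n, 1.0 / (2.0 * (n + lam)))
    Wm = jnp.concatenate([jnp.array([lam / (n + lam)]), w])
    Wc = Wm.at[0].add(1.0 - alpha**2 + beta)  # only the centre weight differs
    return sigmas, Wm, Wc


def unscented_transform(f, x, P, **kw):
    """Approximate mean and covariance of f(X) for X ~ N(x, P)."""
    sigmas, Wm, Wc = merwe_sigma_points(x, P, **kw)
    ys = jax.vmap(f)(sigmas)          # propagate every sigma point through f
    mean = Wm @ ys                    # weighted mean
    dy = ys - mean
    cov = (Wc[:, None] * dy).T @ dy   # weighted outer-product sum
    return mean, cov
```

For a linear map the transform is exact, which makes a convenient sanity check: propagating through `A @ x` returns `A x` and `A P A^T`.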
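The dynamic-systems bullet combines RK4 integration with linearised transition matrices. One way these fit together in JAX is to integrate the ODE over a step with fixed RK4 sub-steps and take `jax.jacfwd` of that flow map. This is a sketch only: `rk4_step`, `integrate`, the `ode(t, x)` signature, and the `dt_max` sub-stepping rule are hypothetical and may not match `DynamicSystem`.

```python
import math

import jax
import jax.numpy as jnp


def rk4_step(ode, t, x, dt):
    """One classical fourth-order Runge-Kutta step for dx/dt = ode(t, x)."""
    k1 = ode(t, x)
    k2 = ode(t + dt / 2, x + dt / 2 * k1)
    k3 = ode(t + dt / 2, x + dt / 2 * k2)
    k4 = ode(t + dt, x + dt * k3)
    return x + dt / 6 * (k1 + 2 * k2 + 2 * k3 + k4)


def integrate(ode, x0, t0, t1, dt_max):
    """Integrate from t0 to t1 with equal RK4 sub-steps no longer than dt_max."""
    n = max(1, math.ceil((t1 - t0) / dt_max))  # static number of sub-steps
    dt = (t1 - t0) / n

    def body(i, x):
        return rk4_step(ode, t0 + i * dt, x, dt)

    return jax.lax.fori_loop(0, n, body, x0)


def linearised_transition(ode, x0, t0, t1, dt_max):
    """Jacobian of the flow map x0 -> x(t1): a discrete transition matrix."""
    return jax.jacfwd(lambda x: integrate(ode, x, t0, t1, dt_max))(x0)
```

For `dx/dt = -x` over one time unit, both the propagated state and the 1x1 transition matrix should be close to `exp(-1)`.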
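Inverse-variance weighting fuses independent Gaussian estimates of the same quantity by summing their precisions: `P = (sum_i P_i^-1)^-1` and `x = P sum_i P_i^-1 x_i`. The sketch below shows only this textbook formula. `inverse_variance_fuse` is a hypothetical name, and whether `bayesian_inverse_variance_weighting` uses exactly this form, or accepts these array shapes, is an assumption.

```python
import jax.numpy as jnp


def inverse_variance_fuse(means, covs):
    """Fuse k independent estimates N(means[i], covs[i]) of one n-vector.

    means: (k, n) array, covs: (k, n, n) array. Returns (x, P).
    """
    infos = jnp.linalg.inv(covs)                   # (k, n, n) precision matrices
    P = jnp.linalg.inv(infos.sum(axis=0))          # fused covariance
    x = P @ jnp.einsum("kij,kj->i", infos, means)  # precision-weighted mean
    return x, P
```

For two scalar estimates 1.0 (variance 1) and 3.0 (variance 3), the fused mean is 1.5 with variance 0.75. The more certain estimate pulls the result towards itself.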
Requires Python 3.12+. From the repo root:

```bash
uv sync                  # create .venv and install locked deps
uv run pip install -e .  # editable install for local changes
```

Without uv:
```bash
pip install -e .
```

Typical workflow:

- Implement a `DynamicSystem` subclass in `gradse.dynsys` with `x_idx` and `ode`.
- Implement one or more `Observation` subclasses in `gradse.observation` (define `t_start`, `t_end`, `n_theta`, `theta_init`, `theta_cov`, `y`, `hx`, `dhx`, `R`, `in_range`).
- Align measurements with `ObservationManager().construct_steps(obs, t_end, dt_max)` and use `om.hx`/`om.dhx` in the filters.
- Instantiate `UnscentedKalmanFilter(dsys, dt_int_max=...)` and set priors and parameter handling with `ParamUnpacker(sys, x0_prior, P0_prior, obs, q_source_init)`.
- Per step: build `Q = process_white_noise(sys.jac(x), pu.Qc(theta), dt, density)`, call `predict`, then `update` (set `hx_batch_enabled=False` if passing a single-state `hx`), and store the results in a `ForwardResult`.
- Optional: smooth with `ukf.rts_smoother(forward.x_post, forward.P_post, forward.Q, forward.delta_t)` and differentiate log-likelihoods via `jax.grad`/`jax.value_and_grad` for parameter tuning.
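The per-step `process_white_noise(sys.jac(x), pu.Qc(theta), dt, density)` call turns a continuous-time noise spectral density into a discrete covariance. A standard way to do this is Van Loan's matrix-exponential method, sketched below for a linear(ised) model `dx = A x dt + dw` with `E[dw dw^T] = Qc dt`. This is an assumed technique: gradse may integrate differently, `van_loan_discretise` is a hypothetical name, and the `density` argument is not modelled here.

```python
import jax.numpy as jnp
from jax.scipy.linalg import expm


def van_loan_discretise(A, Qc, dt):
    """Return (F, Q) with F = expm(A dt) and
    Q = integral_0^dt expm(A s) Qc expm(A s)^T ds."""
    n = A.shape[0]
    # Block matrix whose exponential contains both F and Q.
    M = jnp.block([[-A, Qc], [jnp.zeros((n, n)), A.T]]) * dt
    E = expm(M)
    F = E[n:, n:].T          # lower-right block equals F^T
    Q = F @ E[:n, n:]        # upper-right block equals F^-1 Q
    return F, 0.5 * (Q + Q.T)  # symmetrise against round-off
```

For a constant-velocity model (`A = [[0, 1], [0, 0]]`, noise density `q` on the velocity), this reproduces the familiar closed form `Q = q [[dt^3/3, dt^2/2], [dt^2/2, dt]]`. Because `expm` is differentiable in JAX, gradients with respect to `Qc` flow through the discretisation.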
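The last step above, differentiating the log-likelihood for parameter tuning, works because the whole predict/update recursion is built from differentiable JAX operations. The self-contained sketch below runs a tiny UKF with `jax.lax.scan`, accumulates Gaussian innovation log-densities, and applies `jax.value_and_grad` with respect to a log process-noise scale. Everything here (`ukf_loglik`, the isotropic `q * I` and `r * I` noise, the `alpha = 1` sigma-point scaling) is a hypothetical illustration, not gradse's `UnscentedKalmanFilter`.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import multivariate_normal

jax.config.update("jax_enable_x64", True)


def sigma_points(x, P, alpha=1.0, beta=2.0, kappa=0.0):
    # Merwe scaled sigma points (2n + 1) with mean and covariance weights.
    n = x.shape[0]
    lam = alpha**2 * (n + kappa) - n
    L = jnp.linalg.cholesky((n + lam) * P)
    sig = jnp.concatenate([x[None], x + L.T, x - L.T])
    w = jnp.full(2 * n, 0.5 / (n + lam))
    Wm = jnp.concatenate([jnp.array([lam / (n + lam)]), w])
    Wc = Wm.at[0].add(1.0 - alpha**2 + beta)
    return sig, Wm, Wc


def ukf_loglik(log_q, ys, f, h, x0, P0, r):
    """Sum of innovation log-densities of a UKF run; differentiable in log_q."""
    q = jnp.exp(log_q)

    def step(carry, y):
        x, P = carry
        n = x.shape[0]
        # Predict: push sigma points through the transition f, add process noise.
        sig, Wm, Wc = sigma_points(x, P)
        fx = jax.vmap(f)(sig)
        x_pred = Wm @ fx
        dx = fx - x_pred
        P_pred = (Wc[:, None] * dx).T @ dx + q * jnp.eye(n)
        # Update: redraw sigma points and map them through the measurement h.
        sig, Wm, Wc = sigma_points(x_pred, P_pred)
        hx = jax.vmap(h)(sig)
        y_pred = Wm @ hx
        dy = hx - y_pred
        S = (Wc[:, None] * dy).T @ dy + r * jnp.eye(y.shape[0])
        C = (Wc[:, None] * (sig - x_pred)).T @ dy  # state/measurement cross-covariance
        K = jnp.linalg.solve(S, C.T).T             # Kalman gain C S^-1 (S symmetric)
        x_new = x_pred + K @ (y - y_pred)
        P_new = P_pred - K @ S @ K.T
        ll = multivariate_normal.logpdf(y, y_pred, S)
        return (x_new, P_new), ll

    _, lls = jax.lax.scan(step, (x0, P0), ys)
    return lls.sum()
```

Usage on a toy 1-D random walk observed with noise:

```python
ys = jnp.array([[0.1], [0.3], [0.2], [0.5], [0.4]])
ident = lambda x: x
nll = lambda lq: -ukf_loglik(lq, ys, ident, ident, jnp.zeros(1), jnp.eye(1), 0.1)
value, grad = jax.value_and_grad(nll)(jnp.log(0.05))
```

Optimising in `log_q` rather than `q` keeps the noise variance positive without explicit constraints.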