AeroJAX: JAX-native CFD, differentiable end-to-end. ~560 FPS at 128×128 on CPU [P]

I have been building a JAX-based CFD framework for differentiable Navier-Stokes simulation inside ML loops, such as inverse design and learned closures.

The goal is to keep the full solver stack differentiable so it can sit inside optimisation and learning pipelines.

Design choices:

  • Fully JAX-native, with no external dependencies
  • CPU-first, vectorized implementation
  • End-to-end differentiability through the velocity, pressure, and vorticity fields
  • Support for both Navier-Stokes (projection method) and lattice Boltzmann (LBM, D2Q9) solvers
  • Brinkman-style forcing with smooth masks for handling geometry
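
For readers who haven't used LBM, here is a minimal sketch (not the AeroJAX source) of one D2Q9 BGK step: collision toward the standard second-order equilibrium, then streaming on a periodic grid. The names `equilibrium`, `macroscopic` and `lbm_step`, the relaxation time `tau`, and the periodic boundaries are assumptions for illustration. Because every operation is a `jax.numpy` call, `jax.grad` can differentiate straight through the step.

```python
import jax
import jax.numpy as jnp

# D2Q9 lattice velocities as plain Python ints so they stay static under jit:
# rest, 4 axis-aligned, 4 diagonal.
C = ((0, 0), (1, 0), (0, 1), (-1, 0), (0, -1), (1, 1), (-1, 1), (-1, -1), (1, -1))
CX = jnp.array([c[0] for c in C], dtype=jnp.float32)[:, None, None]  # (9, 1, 1)
CY = jnp.array([c[1] for c in C], dtype=jnp.float32)[:, None, None]
W = jnp.array([4 / 9] + [1 / 9] * 4 + [1 / 36] * 4, dtype=jnp.float32)[:, None, None]


def equilibrium(rho, ux, uy):
    """Second-order BGK equilibrium for all 9 populations, shape (9, nx, ny)."""
    cu = 3.0 * (CX * ux + CY * uy)          # 3 (c_i . u)
    usq = 1.5 * (ux**2 + uy**2)             # (3/2) |u|^2
    return W * rho * (1.0 + cu + 0.5 * cu**2 - usq)


def macroscopic(f):
    """Density and velocity as moments of the populations."""
    rho = f.sum(axis=0)
    return rho, (CX * f).sum(axis=0) / rho, (CY * f).sum(axis=0) / rho


@jax.jit
def lbm_step(f, tau):
    """One BGK collision followed by periodic streaming."""
    rho, ux, uy = macroscopic(f)
    f_post = f - (f - equilibrium(rho, ux, uy)) / tau
    # Population i moves by c_i; axis 0 is x, axis 1 is y.
    return jnp.stack([jnp.roll(f_post[i], shift=C[i], axis=(0, 1)) for i in range(9)])


# Usage: a small density wave at rest on a 64 x 32 periodic grid.
nx, ny = 64, 32
x = jnp.arange(nx, dtype=jnp.float32)[:, None]
rho0 = 1.0 + 0.01 * jnp.sin(2 * jnp.pi * x / nx) * jnp.ones((1, ny))
zeros = jnp.zeros((nx, ny))
f0 = equilibrium(rho0, zeros, zeros)
f1 = lbm_step(f0, 0.8)
```

Collision conserves mass locally and periodic streaming only moves populations around, so the total of `f` should be unchanged from one step to the next up to float32 round-off.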

Currently:

  • 2D incompressible Navier-Stokes solver using a projection (pressure-correction) method
  • LBM solver integrated into the same framework
  • Performance is CPU-bound and depends on grid size:
    • ~560 FPS at 128×128
    • ~300 FPS at 512×96
  • Differentiable flow fields throughout the pipeline
  • Hooks for neural operators and learned corrections inside the solver loop
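
The projection (pressure-correction) step can be illustrated with a spectral Poisson solve on a periodic box. This is a simplified stand-in, not necessarily how AeroJAX does it: the FFT approach, the periodic boundaries, and the names `wavenumbers`, `project` and `divergence` are assumptions. Solving ∇²φ = ∇·u and subtracting ∇φ reduces, in Fourier space, to û ← û − k (k·û)/|k|², which leaves a divergence-free field.

```python
from functools import partial

import jax
import jax.numpy as jnp


def wavenumbers(shape, dx):
    """Angular wavenumbers for a periodic grid with uniform spacing dx."""
    nx, ny = shape
    kx = 2.0 * jnp.pi * jnp.fft.fftfreq(nx, d=dx)[:, None]
    ky = 2.0 * jnp.pi * jnp.fft.fftfreq(ny, d=dx)[None, :]
    return kx, ky


@partial(jax.jit, static_argnames="dx")
def project(u, v, dx):
    """Divergence-free part of (u, v): u_hat <- u_hat - k (k . u_hat) / |k|^2."""
    kx, ky = wavenumbers(u.shape, dx)
    uh, vh = jnp.fft.fft2(u), jnp.fft.fft2(v)
    k2 = kx**2 + ky**2
    k2 = jnp.where(k2 == 0.0, 1.0, k2)   # mean mode: numerator is zero there anyway
    k_dot_u = (kx * uh + ky * vh) / k2   # equals i * phi_hat, where laplacian(phi) = div(u)
    return (jnp.real(jnp.fft.ifft2(uh - kx * k_dot_u)),
            jnp.real(jnp.fft.ifft2(vh - ky * k_dot_u)))


def divergence(u, v, dx):
    """Spectral divergence, used here only to check the projection."""
    kx, ky = wavenumbers(u.shape, dx)
    return jnp.real(jnp.fft.ifft2(1j * (kx * jnp.fft.fft2(u) + ky * jnp.fft.fft2(v))))


# Usage: a known divergence-free field plus a pure gradient field on [0, 2*pi)^2.
n = 64
dx = 2 * jnp.pi / n                     # Python float, so it can be a static argument
xs = jnp.arange(n) * dx
X, Y = jnp.meshgrid(xs, xs, indexing="ij")
u_sol, v_sol = jnp.sin(X) * jnp.cos(Y), -jnp.cos(X) * jnp.sin(Y)   # div = 0
u_grad, v_grad = -jnp.sin(X + 2 * Y), -2 * jnp.sin(X + 2 * Y)        # grad of cos(x + 2y)
u_p, v_p = project(u_sol + u_grad, v_sol + v_grad, dx)
```

The projection should strip the gradient part and return the solenoidal field up to round-off. On non-periodic domains with walls, the FFT solve would typically be replaced by an iterative or sparse Poisson solver, which is where most of the cost of a projection method usually sits.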

Where I see the real value:

  • Inverse design, where geometry maps to flow and gradients propagate back to the geometry
  • Learning turbulence or residual closures directly inside the solver
  • Using CFD as a differentiable data generator for ML systems
  • Hybrid physics and learned models without breaking gradient flow
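
To make the inverse-design point concrete, here is a toy sketch of how a smooth Brinkman mask lets `jax.grad` return a derivative with respect to a geometry parameter, in this case a cylinder radius. Everything in it is hypothetical: the grid constants, the sigmoid mask, the diffusion-only scalar dynamics (no advection or pressure), and the momentum-deficit objective, which stands in for a real drag measure.

```python
import jax
import jax.numpy as jnp

# Hypothetical setup: unit periodic box, cylinder centred at (0.5, 0.5).
N = 64                      # grid points per side
DX = 1.0 / N
NU = 0.01                   # kinematic viscosity
DT = 0.2 * DX**2 / NU       # below the explicit-diffusion limit DX**2 / (4 * NU)
ETA = 1e-3                  # Brinkman permeability (smaller = stiffer solid)
U_INF = 1.0                 # initial uniform velocity

x = (jnp.arange(N) + 0.5) * DX
X, Y = jnp.meshgrid(x, x, indexing="ij")
DIST = jnp.sqrt((X - 0.5) ** 2 + (Y - 0.5) ** 2)


def smooth_mask(radius, eps=2 * DX):
    """~1 inside the cylinder, ~0 outside, with a sigmoid transition of width ~eps."""
    return jax.nn.sigmoid((radius - DIST) / eps)


def laplacian(f):
    """Five-point Laplacian with periodic wrap-around."""
    return (jnp.roll(f, 1, 0) + jnp.roll(f, -1, 0)
            + jnp.roll(f, 1, 1) + jnp.roll(f, -1, 1) - 4.0 * f) / DX**2


def rollout(radius, n_steps=100):
    chi = smooth_mask(radius)
    u0 = jnp.full((N, N), U_INF)

    def step(_, u):
        u = u + DT * NU * laplacian(u)       # explicit viscous diffusion
        return u / (1.0 + DT * chi / ETA)    # implicit Brinkman damping inside the body

    # Static trip count, so fori_loop is reverse-mode differentiable.
    return jax.lax.fori_loop(0, n_steps, step, u0)


def momentum_deficit(radius):
    """Toy objective: mean velocity lost relative to U_INF."""
    return jnp.mean(U_INF - rollout(radius))


d_deficit_d_radius = jax.jit(jax.grad(momentum_deficit))
g = d_deficit_d_radius(0.15)
```

The smoothing width `eps` is the key design knob: a hard step mask would give a zero gradient almost everywhere, while the sigmoid spreads the sensitivity over a few cells around the boundary so that gradients can reach the geometry.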

Most CFD and ML pipelines still treat the solver as a black box, which makes gradient-based design difficult or impossible.

AeroJAX is an attempt to keep the physics structure intact while making the entire pipeline differentiable.

submitted by /u/LackSome307