Hi everyone!
Parax is a library for "parametric modeling" in JAX. It attempts to bridge the gap between pure JAX PyTrees and more object-oriented modeling approaches (e.g. using Equinox).
v0.7 has been released, featuring a more polished API as well as some detailed examples in the documentation.
Some of Parax's features:
- Derived/constrained parameters with metadata
- Computed PyTrees and callable parameterizations
- Abstract interfaces for fixed, bounded, and probabilistic PyTrees and parameters
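For readers unfamiliar with the idea of a derived, bounded parameter, here is a minimal plain-JAX sketch of the underlying technique. It is not Parax's API; the class `BoundedParam` is hypothetical. It stores an unconstrained leaf, exposes a value derived through a sigmoid that always lies inside `(lower, upper)`, and registers itself as a PyTree so that `jax.grad` and `jax.tree_util.tree_map` can optimize it directly.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class


@register_pytree_node_class
class BoundedParam:
    """Hypothetical bounded parameter (not Parax's API).

    Stores an unconstrained `raw` value and derives a value that is
    always strictly inside (lower, upper) via a sigmoid transform.
    """

    def __init__(self, raw, lower, upper):
        self.raw = raw      # unconstrained, trainable leaf
        self.lower = lower  # static bound (not differentiated)
        self.upper = upper  # static bound (not differentiated)

    @property
    def value(self):
        # Derived value: the sigmoid maps raw into (0, 1), then rescale
        return self.lower + (self.upper - self.lower) * jax.nn.sigmoid(self.raw)

    def tree_flatten(self):
        # Only `raw` is a PyTree leaf; the bounds travel as static aux data
        return (self.raw,), (self.lower, self.upper)

    @classmethod
    def tree_unflatten(cls, aux, children):
        return cls(children[0], *aux)


def loss(p):
    # The target 5.0 lies outside (0, 1), so the optimum is pushed
    # toward the upper bound without ever crossing it
    return (p.value - 5.0) ** 2


p = BoundedParam(jnp.array(0.0), 0.0, 1.0)
grad_fn = jax.jit(jax.grad(loss))

# Plain gradient descent on the unconstrained leaf
for _ in range(200):
    g = grad_fn(p)
    p = jax.tree_util.tree_map(lambda x, dx: x - 0.1 * dx, p, g)

print(float(p.value))
```

Keeping the bounds as static aux data means only `raw` is seen by the optimizer, which is one common way to get bounded optimization without a constrained solver.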
Two new examples in the docs show off these features:
- Bounded optimization (JAXopt)
- Bayesian sampling (BlackJAX)
Perhaps the library will be of use to someone; feel free to leave any feedback!
Cheers,
Gary