Papers
arxiv:2311.18727

Automatic Functional Differentiation in JAX

Published on Nov 30, 2023
Authors:

Abstract

We extend JAX with the capability to automatically differentiate higher-order functions (functionals and operators). By representing functions as a generalization of arrays, we seamlessly use JAX's existing primitive system to implement higher-order functions. We present a set of primitive <PRE_TAG>operators</POST_TAG> that serve as foundational building blocks for constructing several key types of functionals. For every introduced primitive operator, we derive and implement both linearization and transposition rules, aligning with JAX's internal protocols for forward and reverse mode automatic differentiation. This enhancement allows for functional differentiation in the same syntax traditionally use for functions. The resulting functional gradients are themselves functions ready to be invoked in python. We showcase this tool's efficacy and simplicity through applications where functional derivatives are indispensable. The source code of this work is released at https://github.com/sail-sg/autofd .

Community

Sign up or log in to comment

Models citing this paper 0

No model linking this paper

Cite arxiv.org/abs/2311.18727 in a model README.md to link it from this page.

Datasets citing this paper 0

No dataset linking this paper

Cite arxiv.org/abs/2311.18727 in a dataset README.md to link it from this page.

Spaces citing this paper 0

No Space linking this paper

Cite arxiv.org/abs/2311.18727 in a Space README.md to link it from this page.

Collections including this paper 0

No Collection including this paper

Add this paper to a collection to link it from this page.