automatic differentiation



Automatic differentiation (AD) is a technique for computing (transposed) derivatives of functions implemented by computer programs, essentially by applying the chain rule across program code. It is typically the method of choice for computing derivatives in machine learning and scientific computing because of its efficiency and numerical stability.
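As a first concrete illustration (a minimal sketch, independent of the categorical development below): forward-mode AD can be implemented with dual numbers, where every intermediate value carries a tangent alongside its value and each primitive operation applies the chain rule locally. The `Dual` class and `sin` helper are hypothetical names chosen for illustration.

```python
import math

# Minimal dual-number forward-mode AD: each value carries its derivative,
# and arithmetic applies the chain rule operation by operation.
class Dual:
    def __init__(self, primal, tangent=0.0):
        self.primal = primal    # the value
        self.tangent = tangent  # the derivative along the seed direction

    def __add__(self, other):
        return Dual(self.primal + other.primal, self.tangent + other.tangent)

    def __mul__(self, other):
        # product rule: (f*g)' = f'*g + f*g'
        return Dual(self.primal * other.primal,
                    self.tangent * other.primal + self.primal * other.tangent)

def sin(x):
    # chain rule for the primitive sin: (sin f)' = cos(f) * f'
    return Dual(math.sin(x.primal), math.cos(x.primal) * x.tangent)

# Differentiate f(x) = x * sin(x) at x = 2 by seeding the tangent with 1.
x = Dual(2.0, 1.0)
y = x * sin(x)
# y.tangent == sin(2) + 2*cos(2), the exact derivative
```

The seed `tangent=1.0` selects the direction of differentiation; for $f : R^n \to R^m$, one forward pass computes one directional derivative.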

Forward and reverse automatic differentiation

AD works by calculating the (transposed) derivative of a composite program in terms of the (transposed) derivatives of its parts, by using the chain rule. The distinction between derivatives and transposed derivatives leads to the main distinction in automatic differentiation modes:

When calculating a derivative of a program that implements a function $f : R^n \to R^m$, reverse mode tends to be the more efficient algorithm if $n \gg m$, and forward mode tends to be more efficient if $n \ll m$. Seeing that many tasks in machine learning and statistics require the calculation of derivatives (for use in gradient-based optimization or Monte Carlo sampling) of functions $f : R^n \to R$ (e.g. probability density functions) for $n$ very large, reverse mode AD tends to be the most popular algorithm.
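To make the cost asymmetry concrete, here is a minimal tape-based reverse-mode sketch (the names `Var` and `grad` are illustrative, not taken from any particular library): one forward pass records the computation, and a single backward pass then yields all $n$ partial derivatives of a scalar output.

```python
import math

# Tape-based reverse-mode AD: the forward pass records each operation
# together with its local partial derivatives; the backward pass
# propagates adjoints from the single output back to all inputs.
class Var:
    def __init__(self, value, parents=()):
        self.value = value
        self.parents = parents  # pairs (parent, local partial derivative)
        self.adjoint = 0.0

    def __add__(self, other):
        return Var(self.value + other.value, [(self, 1.0), (other, 1.0)])

    def __mul__(self, other):
        return Var(self.value * other.value,
                   [(self, other.value), (other, self.value)])

def sin(x):
    return Var(math.sin(x.value), [(x, math.cos(x.value))])

def grad(output):
    # Depth-first post-order gives a valid topological order for the
    # reverse sweep (children are processed before their parents).
    order, seen = [], set()
    def visit(v):
        if id(v) in seen:
            return
        seen.add(id(v))
        for p, _ in v.parents:
            visit(p)
        order.append(v)
    visit(output)
    output.adjoint = 1.0
    for v in reversed(order):
        for p, local in v.parents:
            p.adjoint += v.adjoint * local

# Gradient of f(x1, x2) = x1 * x2 + sin(x1): one backward pass yields
# both partial derivatives, however large the number of inputs.
x1, x2 = Var(0.5), Var(4.0)
y = x1 * x2 + sin(x1)
grad(y)
# x1.adjoint == x2 + cos(x1), x2.adjoint == x1
```

With $n$ inputs and one output, the forward-mode sketch above would need $n$ passes (one per seed direction), while this reverse sweep needs only one.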

Combinatory Homomorphic Automatic Differentiation (CHAD) - a categorical take on AD

Let us fix some class of categories with stuff $S$. We will call its members $S$-categories. For our purposes, the property part of $S$ might include, for example, Cartesian closure, and the structure part might include designated objects and morphisms (intended as base types and primitive operations).

The idea behind CHAD will be to view forward and reverse AD as the unique structure (stuff) preserving functors (homomorphisms of $S$-categories) from the initial $S$-category $Syn$ to two suitably chosen $S$-categories $\Sigma_{CSyn} LSyn$ and $\Sigma_{CSyn} LSyn^{op}$.

The source language

Consider the initial $S$-category $Syn$ (put differently, the category with $S$-properties that is freely generated from the $S$-structure). We can think of this category as a programming language: its objects are types and its morphisms are programs, modulo beta- and eta-equivalence. In fact, for a wide class of programming languages, we can find a suitable choice of stuff $S$ such that the programming language arises as the initial $S$-category. We will refer to this category as the source language of our AD transformation: its morphisms are the programs we want to differentiate.

For example, if we choose the property part of $S$ to consist of Cartesian closure and the structure part of $S$ to consist of designated objects $R$ and $Z$ and morphisms $sin$, $cos$, $(+)$, and $(*)$, then $Syn$ is the simply typed lambda-calculus with base types $R$ and $Z$ and the primitive operations $sin$, $cos$, $(+)$, and $(*)$.
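To make the reading of morphisms as programs concrete, such a $Syn$ can be sketched as composable functions (an illustrative shallow embedding, not the article's formal syntax): identity, composition, and pairing provide the categorical plumbing, and the primitive operations generate everything else.

```python
import math

# Shallow embedding of the example source language: morphisms are
# composable functions, generated by the primitives sin, (+), (*).
def identity(x):
    return x

def compose(g, f):
    # categorical composition g . f: apply f, then g
    return lambda x: g(f(x))

def pair(f, g):
    # the universal map into a product: <f, g>(x) = (f(x), g(x))
    return lambda x: (f(x), g(x))

# primitive operations from the structure part of S
sin = math.sin
add = lambda xy: xy[0] + xy[1]
mul = lambda xy: xy[0] * xy[1]

# the program  x |-> x * sin(x)  as a composite morphism R -> R
program = compose(mul, pair(identity, sin))
```

Every program in this fragment is a composite of these few combinators, which is exactly why a functor defined on the generators extends uniquely to all programs.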

The target language and deriving AD

Given a strictly indexed category $L : C^{op} \to Cat$, we can form its Grothendieck construction $\Sigma_C L$, which is a split fibration over $C$. Similarly, we can take the fibrewise opposite category $L^{op} : C^{op} \to Cat$ of $L$ and form $\Sigma_C L^{op}$ from that.
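Concretely (a standard unpacking of the Grothendieck construction, stated here for orientation), the objects and morphisms of these total categories are:

```latex
% Objects of \Sigma_C L: pairs of a base object and a fibre object
(A, A') \quad \text{with } A \in C,\; A' \in L(A).
% Morphisms (A, A') \to (B, B') in \Sigma_C L:
(f, f') \quad \text{with } f : A \to B \ \text{in } C,
         \qquad f' : A' \to L(f)(B') \ \text{in } L(A).
% In \Sigma_C L^{op} the fibre component runs backwards,
% as befits a transposed derivative:
(f, f') \quad \text{with } f : A \to B \ \text{in } C,
         \qquad f' : L(f)(B') \to A' \ \text{in } L(A).
```

The first components will play the role of primal programs, the second components that of tangent (respectively cotangent) programs.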

For most natural choices of stuff $S$, we can find elegant sufficient conditions on $C$ and $L$ that guarantee that $\Sigma_C L$ and $\Sigma_C L^{op}$ are both $S$-categories. Let us call a strictly indexed category $L : C^{op} \to Cat$ satisfying these conditions an $S'$-category (where we think of $S'$ again as stuff). The corresponding $S'$ can be worked out for different choices of $S$.

Let us write $LSyn : CSyn^{op} \to Cat$ for the initial $S'$-category. We think of this category as the target language of automatic differentiation, in the sense that the forward/reverse derivatives of programs (morphisms) in the source language $Syn$ consist of an associated primal program that is a morphism in $CSyn$ and an associated tangent/cotangent program that is a morphism in $LSyn$/$LSyn^{op}$. As a programming language, $LSyn$ is a linear dependent type theory over the Cartesian type theory $CSyn$.

Indeed, as for any $S'$-category $L : C^{op} \to Cat$, $\Sigma_C L$ and $\Sigma_C L^{op}$ are $S$-categories, it follows that $\Sigma_{CSyn} LSyn$ and $\Sigma_{CSyn} LSyn^{op}$ are, in particular, $S$-categories. Seeing that $Syn$ is the initial $S$-category, we obtain unique morphisms of $S$-categories: the forward AD transformation $Df : Syn \to \Sigma_{CSyn} LSyn$ and the reverse AD transformation $Dr : Syn \to \Sigma_{CSyn} LSyn^{op}$.
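At first-order types, the forward-mode transformation $Df$ admits a simple concrete reading: a morphism is sent to its primal value together with its derivative as a linear map on tangents, and composition pairs primals as usual while composing the linear maps. The Python sketch below (a hypothetical encoding; all names are chosen for illustration) shows this pairing discipline:

```python
import math

# Forward-mode pairing discipline at first-order types: each morphism
# is sent to a function returning (primal value, derivative as a linear
# map on tangents).
def d_sin(x):
    return math.sin(x), lambda dx: math.cos(x) * dx

def d_mul(xy):
    # derivative of (*) is the linear map (dx, dy) |-> y*dx + x*dy
    return xy[0] * xy[1], lambda dxy: xy[1] * dxy[0] + xy[0] * dxy[1]

def d_id(x):
    return x, lambda dx: dx

def d_compose(dg, df):
    # D(g . f)(x) = (g(f(x)), Dg(f(x)) . Df(x)):
    # primals compose as usual, tangent maps compose linearly
    def dgf(x):
        y, f_lin = df(x)
        z, g_lin = dg(y)
        return z, lambda dx: g_lin(f_lin(dx))
    return dgf

def d_pair(df, dg):
    # pairing: primals are paired, tangent maps act componentwise
    def dfg(x):
        y, f_lin = df(x)
        z, g_lin = dg(x)
        return (y, z), lambda dx: (f_lin(dx), g_lin(dx))
    return dfg

# Df of the program  x |-> x * sin(x)
d_program = d_compose(d_mul, d_pair(d_id, d_sin))
value, derivative = d_program(2.0)
# derivative(1.0) == sin(2) + 2*cos(2)
```

In the reverse-mode transformation $Dr$, the second component would instead be the transposed linear map, running from output cotangents back to input cotangents.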

Semantics of the source and target languages

The category $Set$ of sets and functions gives another example of an $S$-category, for many choices of $S$ (if we choose the sets $[[R]]$ and functions $[[op]]$ that we would like to denote with the types $R$ and operations $op$ in the structure part of $S$). Moreover, the strictly indexed category $Fam(CMon) : Set^{op} \to Cat$ of families of commutative monoids tends to give an example of an $S'$-category. By initiality of $Syn$ and $LSyn : CSyn^{op} \to Cat$, we obtain unique structure-preserving semantics functors $[[-]]$ into $Set$ and $Fam(CMon)$, respectively.

In particular, we see that source language programs $t \in Syn(A, B)$ get interpreted as functions $[[t]] \in Set([[A]], [[B]])$. Similarly, programs $t \in CSyn(A, B)$ in the target language with Cartesian type get interpreted as functions $[[t]] \in Set([[A]], [[B]])$. Finally, programs $t \in LSyn(A)(B, C)$ in the target language with linear type get interpreted as $[[t]] \in Fam(CMon)([[A]])([[B]], [[C]])$: families of monoid homomorphisms.
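Concretely, $Fam(CMon)$ can be described as follows (a standard description, included for orientation):

```latex
% The fibre over a set A: A-indexed families of commutative monoids,
Fam(CMon)(A) \;:\; \text{objects are families } (B_a)_{a \in A},\ B_a \in CMon,
% with morphisms given by families of monoid homomorphisms:
(B_a)_{a \in A} \to (C_a)_{a \in A}
  \quad\text{given by}\quad (g_a : B_a \to C_a)_{a \in A}.
% Reindexing along f : A'' \to A acts by precomposition of indices:
Fam(CMon)(f)\big((B_a)_{a \in A}\big) \;=\; (B_{f(a'')})_{a'' \in A''}.
```

The commutative monoid structure on the fibres is what interprets the zero and addition of (co)tangents that reverse mode needs to accumulate adjoints.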

Correctness of CHAD

We say that CHAD calculates the correct derivative (resp. transposed derivative) of a program $s$ if the semantics $[[Df(s)]]$ of the program $Df(s)$ (resp. $[[Dr(s)]]$ of $Dr(s)$) equals the pair of the semantics $[[s]]$ of $s$ and the derivative $T[[s]]$ (resp. transposed derivative $T^*[[s]]$) of the semantics $[[s]]$ of $s$. CHAD is correct in the sense that it calculates the correct (transposed) derivative of any composite (possibly higher-order) program between first-order types (meaning: types built using only positive type formers), provided that it calculates the correct (transposed) derivatives of all primitive operations, like $(*)$, that we used to generate the source language. That is, CHAD is a valid way of compositionally calculating (transposed) derivatives of composite computer programs, as long as we correctly implement the derivatives of all primitive operations (basic mathematical functions like multiplication, addition, sine, and cosine) in the language.

We can prove this by a standard logical relations argument, relating smooth curves to their primal and (co)tangent curves. Viewed more abstractly, the proof follows automatically because the Artin gluing along a representable functor (such as the hom out of the real numbers) of an $S$-category is itself again an $S$-category, for most common choices of $S$.


References

Forward mode automatic differentiation was introduced by Robert Edwin Wengert in

An early description of reverse mode automatic differentiation can be found in

but it was already described earlier by others such as Seppo Linnainmaa:

A categorical analysis of (non-interleaved) forward and reverse AD is given by

A categorical analysis of (interleaved) forward mode AD for calculating higher order derivatives is given by