§ Comparison of forward and reverse mode AD
Quite a lot of ink has been spilt on this topic. My favourite reference
is the one by Rufflewind.
However, none of these references have a good stock of worked examples
illustrating the difference. So here, I catalogue the explicit computations
that contrast forward mode AD with reverse mode AD.
In general, in forward mode AD, we fix how much the inputs wiggle with
respect to a parameter t. We figure out how much the output wiggles
with respect to t. If $output = f(input_1, input_2, \dots, input_n)$, then
$$\frac{\partial\, output}{\partial t} = \sum_i \frac{\partial f}{\partial input_i}\,\frac{\partial input_i}{\partial t}.$$
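Concretely, forward mode is often implemented with dual numbers: every value
carries its wiggle with respect to t alongside it. Here is a minimal Python
sketch (the `Dual` class and the example function are my own illustration,
not any particular library's API):

```python
class Dual:
    """A value together with its wiggle ∂val/∂t w.r.t. a parameter t."""
    def __init__(self, val, dot):
        self.val = val  # primal value
        self.dot = dot  # ∂val/∂t

    def __add__(self, other):
        # Sum rule: the wiggles add.
        return Dual(self.val + other.val, self.dot + other.dot)

    def __mul__(self, other):
        # Product rule: ∂(xy)/∂t = y·∂x/∂t + x·∂y/∂t.
        return Dual(self.val * other.val,
                    other.val * self.dot + self.val * other.dot)

# Derivative of f(x, y) = x*y + x at (3, 4) w.r.t. x: seed t = x,
# i.e. ∂x/∂t = 1, ∂y/∂t = 0.
x, y = Dual(3.0, 1.0), Dual(4.0, 0.0)
f = x * y + x
print(f.val, f.dot)  # 15.0 5.0, since ∂f/∂x = y + 1 = 5
```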
In reverse mode AD, we fix how much the parameter t wiggles with
respect to the output. We figure out how much the parameter t
wiggles with respect to the inputs.
If $output_i = f_i(input, \dots)$, then
$$\frac{\partial t}{\partial input} = \sum_i \frac{\partial t}{\partial output_i}\,\frac{\partial f_i}{\partial input}.$$
This is a much messier expression, since we need to accumulate the data
over all outputs.
Essentially, deriving output from input is easy, since how to compute an output
from an input is documented in one place. Deriving input from output is
annoying, since many outputs can depend on a single input.
The upshot is that if we have few "root outputs" (like a loss function),
we need to run AD only once with respect to this output, and we get the
wiggles of all the inputs at the same time, since we compute the wiggles
from output to input.
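Here is a matching minimal reverse-mode sketch in Python (again my own
illustration; `Var` and `backprop` are hypothetical names, not a real
library's API). The `+=` in each rule is exactly the accumulation over
outputs from the equation above, and one backward pass from the root output
fills in the gradient of every input:

```python
class Var:
    """A node that accumulates ∂t/∂self over every output it feeds into."""
    def __init__(self, val):
        self.val = val
        self.grad = 0.0               # accumulated sensitivity
        self._backward = lambda: None
        self._parents = []

    def __add__(self, other):
        out = Var(self.val + other.val)
        out._parents = [self, other]
        def _backward():
            self.grad += out.grad     # ∂(x+y)/∂x = 1
            other.grad += out.grad    # ∂(x+y)/∂y = 1
        out._backward = _backward
        return out

    def __mul__(self, other):
        out = Var(self.val * other.val)
        out._parents = [self, other]
        def _backward():
            self.grad += other.val * out.grad  # ∂(xy)/∂x = y
            other.grad += self.val * out.grad  # ∂(xy)/∂y = x
        out._backward = _backward
        return out

def backprop(root):
    """Push sensitivities from the root output back to all inputs."""
    order, seen = [], set()
    def visit(v):
        if id(v) not in seen:
            seen.add(id(v))
            for p in v._parents:
                visit(p)
            order.append(v)
    visit(root)
    root.grad = 1.0                   # seed t = root, so ∂t/∂root = 1
    for v in reversed(order):
        v._backward()

x, y = Var(3.0), Var(4.0)
z = x * y + x
backprop(z)
print(x.grad, y.grad)  # 5.0 3.0: both gradients from a single pass
```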
The first example, z = max(x, y),
captures the essential difference
between the two approaches succinctly. Study this, and everything else will make
sense.
§ Maximum: z = max(x, y)
$$
\begin{aligned}
z &= \max(x, y) \\
\frac{\partial x}{\partial t} &= ? \\
\frac{\partial y}{\partial t} &= ? \\
\frac{\partial z}{\partial t} &=
  \begin{cases}
    \frac{\partial x}{\partial t} & \text{if } x > y \\
    \frac{\partial y}{\partial t} & \text{otherwise}
  \end{cases}
\end{aligned}
$$
We can compute $\partial z/\partial x$ by setting t = x,
that is, $\partial x/\partial t = 1, \partial y/\partial t = 0$.
Similarly, we can compute $\partial z/\partial y$ by setting t = y,
that is, $\partial x/\partial t = 0, \partial y/\partial t = 1$.
If we want both gradients $\partial z/\partial x$ and $\partial z/\partial y$,
we have to run the above equations twice, once with each initialization.
In our equations, we are saying that we know how sensitive
the inputs x,y are to a given parameter t. We are deriving how sensitive
the output z is to the parameter t as a composition of x,y. If
x>y, then we know that z is as sensitive to t as x is.
$$
\begin{aligned}
z &= \max(x, y) \\
\frac{\partial t}{\partial z} &= ? \\
\frac{\partial t}{\partial x} &=
  \begin{cases}
    \frac{\partial t}{\partial z} & \text{if } x > y \\
    0 & \text{otherwise}
  \end{cases} \\
\frac{\partial t}{\partial y} &=
  \begin{cases}
    \frac{\partial t}{\partial z} & \text{if } y > x \\
    0 & \text{otherwise}
  \end{cases}
\end{aligned}
$$
We can compute $\partial z/\partial x, \partial z/\partial y$
in one shot by setting t = z. That is, $\partial t/\partial z = 1$.
In our equations, we are saying that we know how sensitive
the parameter t is to a given output z. We are trying to see
how sensitive t is to the inputs x, y. If x is active (i.e., x > y),
then t is indeed sensitive to x, and $\partial t/\partial x = 1$.
Otherwise, it is not sensitive, and $\partial t/\partial x = 0$.
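As a Python sketch (the function names are mine), the two rules look like
this. Note that the forward version must be seeded twice to recover both
gradients, while the reverse version returns both in one call:

```python
def max_forward(x, y, dx, dy):
    # Given ∂x/∂t and ∂y/∂t, return (z, ∂z/∂t).
    return (max(x, y), dx if x > y else dy)

def max_reverse(x, y, dz):
    # Given ∂t/∂z, return (∂t/∂x, ∂t/∂y): only the active input is sensitive.
    return (dz, 0.0) if x > y else (0.0, dz)

print(max_forward(3.0, 2.0, 1.0, 0.0))  # seed t = x: (3.0, 1.0)
print(max_forward(3.0, 2.0, 0.0, 1.0))  # seed t = y: (3.0, 0.0)
print(max_reverse(3.0, 2.0, 1.0))       # seed t = z: (1.0, 0.0) in one shot
```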
§ sin: z = sin(x)
$$
\begin{aligned}
z &= \sin(x) \\
\frac{\partial x}{\partial t} &= ? \\
\frac{\partial z}{\partial t} &= \frac{\partial z}{\partial x}\frac{\partial x}{\partial t} = \cos(x)\,\frac{\partial x}{\partial t}
\end{aligned}
$$
We can compute $\partial z/\partial x$ by setting t = x,
that is, by setting $\partial x/\partial t = 1$.
$$
\begin{aligned}
z &= \sin(x) \\
\frac{\partial t}{\partial z} &= ? \\
\frac{\partial t}{\partial x} &= \frac{\partial t}{\partial z}\frac{\partial z}{\partial x} = \frac{\partial t}{\partial z}\cos(x)
\end{aligned}
$$
We can compute $\partial z/\partial x$ by setting t = z,
that is, by setting $\partial t/\partial z = 1$.
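A Python sketch of both rules (function names mine):

```python
import math

def sin_forward(x, dx):
    # Given ∂x/∂t, return (z, ∂z/∂t) = (sin x, cos(x)·∂x/∂t).
    return math.sin(x), math.cos(x) * dx

def sin_reverse(x, dz):
    # Given ∂t/∂z, return ∂t/∂x = ∂t/∂z · cos(x).
    return dz * math.cos(x)

print(sin_forward(0.5, 1.0)[1])  # seed t = x
print(sin_reverse(0.5, 1.0))     # seed t = z; both print cos(0.5) ≈ 0.87758
```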
§ addition: z = x + y
$$
\begin{aligned}
z &= x + y \\
\frac{\partial x}{\partial t} &= ? \\
\frac{\partial y}{\partial t} &= ? \\
\frac{\partial z}{\partial t} &= \frac{\partial z}{\partial x}\frac{\partial x}{\partial t} + \frac{\partial z}{\partial y}\frac{\partial y}{\partial t} = 1 \cdot \frac{\partial x}{\partial t} + 1 \cdot \frac{\partial y}{\partial t} = \frac{\partial x}{\partial t} + \frac{\partial y}{\partial t}
\end{aligned}
$$
$$
\begin{aligned}
z &= x + y \\
\frac{\partial t}{\partial z} &= ? \\
\frac{\partial t}{\partial x} &= \frac{\partial t}{\partial z}\frac{\partial z}{\partial x} = \frac{\partial t}{\partial z}\cdot 1 = \frac{\partial t}{\partial z} \\
\frac{\partial t}{\partial y} &= \frac{\partial t}{\partial z}\frac{\partial z}{\partial y} = \frac{\partial t}{\partial z}\cdot 1 = \frac{\partial t}{\partial z}
\end{aligned}
$$
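A Python sketch (function names mine); note that addition's derivative rules
never need the primal values x and y:

```python
def add_forward(x, y, dx, dy):
    # Given ∂x/∂t and ∂y/∂t, return (z, ∂z/∂t) = (x + y, ∂x/∂t + ∂y/∂t).
    return x + y, dx + dy

def add_reverse(dz):
    # The sensitivity ∂t/∂z flows through to both inputs unchanged.
    return dz, dz
```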
§ multiplication: z = xy
$$
\begin{aligned}
z &= xy \\
\frac{\partial x}{\partial t} &= ? \\
\frac{\partial y}{\partial t} &= ? \\
\frac{\partial z}{\partial t} &= \frac{\partial z}{\partial x}\frac{\partial x}{\partial t} + \frac{\partial z}{\partial y}\frac{\partial y}{\partial t} = y\,\frac{\partial x}{\partial t} + x\,\frac{\partial y}{\partial t}
\end{aligned}
$$
$$
\begin{aligned}
z &= xy \\
\frac{\partial t}{\partial z} &= ? \\
\frac{\partial t}{\partial x} &= \frac{\partial t}{\partial z}\frac{\partial z}{\partial x} = \frac{\partial t}{\partial z}\cdot y \\
\frac{\partial t}{\partial y} &= \frac{\partial t}{\partial z}\frac{\partial z}{\partial y} = \frac{\partial t}{\partial z}\cdot x
\end{aligned}
$$
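A Python sketch (function names mine). Unlike addition, the reverse rule for
multiplication needs the primal values, which is why reverse-mode
implementations must remember intermediate values from the forward pass:

```python
def mul_forward(x, y, dx, dy):
    # Given ∂x/∂t and ∂y/∂t, return (z, ∂z/∂t) = (xy, y·∂x/∂t + x·∂y/∂t).
    return x * y, y * dx + x * dy

def mul_reverse(x, y, dz):
    # (∂t/∂x, ∂t/∂y) = (∂t/∂z · y, ∂t/∂z · x): each input's sensitivity
    # depends on the *other* input's value.
    return dz * y, dz * x
```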
§ subtraction: z = x - y
$$
\begin{aligned}
z &= x - y \\
\frac{\partial x}{\partial t} &= ? \\
\frac{\partial y}{\partial t} &= ? \\
\frac{\partial z}{\partial t} &= \frac{\partial z}{\partial x}\frac{\partial x}{\partial t} + \frac{\partial z}{\partial y}\frac{\partial y}{\partial t} = 1 \cdot \frac{\partial x}{\partial t} - 1 \cdot \frac{\partial y}{\partial t} = \frac{\partial x}{\partial t} - \frac{\partial y}{\partial t}
\end{aligned}
$$
$$
\begin{aligned}
z &= x - y \\
\frac{\partial t}{\partial z} &= ? \\
\frac{\partial t}{\partial x} &= \frac{\partial t}{\partial z}\frac{\partial z}{\partial x} = \frac{\partial t}{\partial z}\cdot 1 = \frac{\partial t}{\partial z} \\
\frac{\partial t}{\partial y} &= \frac{\partial t}{\partial z}\frac{\partial z}{\partial y} = \frac{\partial t}{\partial z}\cdot(-1) = -\frac{\partial t}{\partial z}
\end{aligned}
$$
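A Python sketch (function names mine); subtraction is just addition with a
sign flip on y's rule:

```python
def sub_forward(x, y, dx, dy):
    # Given ∂x/∂t and ∂y/∂t, return (z, ∂z/∂t) = (x − y, ∂x/∂t − ∂y/∂t).
    return x - y, dx - dy

def sub_reverse(dz):
    # ∂t/∂x = ∂t/∂z, while ∂t/∂y = −∂t/∂z.
    return dz, -dz
```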