§ Mutorch

Minimal reverse mode AD implementation.
#!/usr/bin/env python3

# x2, w1, w2 are Leaf variables
# x1 = f(w1, w2)
# y = g(x1, x2)
# loss = h(y)


# BACKPROP [reverse mode AD]
# ==========================
# t is a hallucinated variable.
# y = f(x)
# GIVEN: dt/dy
# TO FIND: dt/dx
# dt/dx = dt/dy * dy/dx
# Base case: pick t = loss itself, so the seed passed into backprop is
# dt/dloss = dloss/dloss = 1
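
# A tiny numeric sketch of the rule above (illustrative functions, not part
# of Mutorch): with y = g(x) = x**2 and loss = h(y) = 3*y,
# dloss/dx = dloss/dy * dy/dx = 3 * 2x.
def _chain_rule_check(x=2.0):
    dloss_dy = 3.0            # dh/dy for loss = 3*y
    dy_dx = 2 * x             # dg/dx for y = x**2
    return dloss_dy * dy_dx   # chain rule

assert _chain_rule_check() == 12.0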


# y1 = f(x1, x2, x3)
# y2 = g(x1, x2, x3)

# FORWARD MODE: [Tangent space] ---- objects of the form (partial f/partial x)
# total gradient of x1: df/dx1 + dg/dx1
# total gradient of x2: df/dx2 + dg/dx2
# total gradient of x3: df/dx3 + dg/dx3

# l = r cos(theta)
# dl = cos(theta) dr - r sin(theta) dtheta
# dl/dtheta = cos(theta) dr/dtheta - r sin(theta) dtheta/dtheta
#           = cos(theta) * 0       - r sin(theta) * 1         = -r sin(theta)

# dl/dr = cos(theta) dr/dr - r sin(theta) dtheta/dr
#       = cos(theta) * 1   - r sin(theta) * 0                 = cos(theta)
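
# Forward mode can be sketched with dual numbers: each value carries its
# tangent (its derivative w.r.t. one chosen input) along with it. This is
# only an illustrative aside with made-up names, not part of Mutorch.
import math

class Dual:
    def __init__(self, val, tan):
        self.val = val   # primal value
        self.tan = tan   # derivative of val w.r.t. the seeded input

# l = r cos(theta); differentiate w.r.t. theta by seeding tan=1 on theta, 0 on r
def _l_forward(r, theta):
    return Dual(r.val * math.cos(theta.val),
                r.tan * math.cos(theta.val)
                - r.val * math.sin(theta.val) * theta.tan)

_dl = _l_forward(Dual(2.0, 0.0), Dual(0.5, 1.0))
assert abs(_dl.tan - (-2.0 * math.sin(0.5))) < 1e-12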

# REVERSE MODE: [CoTangent space] --- objects of the form df
# total gradient of y1: dy1 = (df/dx1)dx1 + (df/dx2)dx2  + (df/dx3)dx3
# total gradient of y2: dy2 = (dg/dx1)dx1 + (dg/dx2)dx2  + (dg/dx3)dx3
# HALLUCINATED T:
#    y1 = f(x1, x2, x3)
#    GIVEN:   dt/dy1 [output]
#    TO FIND: dt/dx1, dt/dx2, dt/dx3 [inputs]
#    SOLN:    dt/dxi = dt/dy * dy/dxi
#                    = dt/dy * df/dxi

class Expr:
    def __mul__(self, other):
        return Mul(self, other)
    def __add__(self, other):
        return Add(self, other)

    def clear_grad(self):
        pass

class Var(Expr):
    def __init__(self, name, val):
        self.name = name
        self.val = val
        self._grad = 0
    def __str__(self):
        return "(var-%s | %s)" % (self.name, self.val)
    def __repr__(self):
        return self.__str__()

    def clear_grad(self):
        self._grad = 0

    def backprop(self, dt_doutput):
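        # Accumulate (+=) rather than overwrite: a Var can feed several
        # downstream nodes (or both slots of one node, e.g. Mul(x, x)),
        # so its total gradient is the sum of dt/dx over all paths.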
        self._grad += dt_doutput

    def grad(self):
        return self._grad

class Mul(Expr):
    def __init__(self, lhs, rhs):
        self.lhs = lhs
        self.rhs = rhs
        self.val = self.lhs.val * self.rhs.val
    def __str__(self):
        return "(* %s %s | %s)" % (self.lhs, self.rhs, self.val)
    def __repr__(self):
        return self.__str__()

    #         -------- input1
    #   S    /
    #  ---> v
    #  <--output *
    #      ^
    #       \_________ input2
    # think in terms of sensitivity.
    # - output has sensitivity S (i.e. dt/doutput = S)
    # - output = input1 * input2
    # - how much sensitivity does input1 get?
    # - S scaled by input2 (and symmetrically input2 gets S scaled by input1),
    #   because each input is scaled by the other before reaching the output
    # output = f(input1, input2); f(input1, input2) = input1 * input2
    def backprop(self, dt_output):
        # dt/dinput1 = dt/doutput * doutput/dinput1
        #            = dt/doutput * d(f(input1, input2))/dinput1
        #            = dt/doutput * d(input1 * input2)/dinput1
        #            = dt/doutput * input2
        self.lhs.backprop(dt_output * self.rhs.val)
        self.rhs.backprop(dt_output * self.lhs.val)
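
# A quick spot-check of Mul.backprop (example values assumed, not part of
# the original driver): d(a*b)/da = b and d(a*b)/db = a.
_ma = Var("a", 3)
_mb = Var("b", 4)
Mul(_ma, _mb).backprop(1)          # seed dt/doutput = 1, i.e. t = a*b
assert _ma.grad() == 4 and _mb.grad() == 3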

# forward pass runs top to bottom; gradients flow back up (^):
# a = ...   ^
# b = ...   ^
# c = a + b ^
#
class Add(Expr):
    def __init__(self, lhs, rhs):
        self.lhs = lhs
        self.rhs = rhs
        self.val = self.lhs.val + self.rhs.val
    def __str__(self):
        return "(+ %s %s | %s)" % (self.lhs, self.rhs, self.val)
    def __repr__(self):
        return self.__str__()

    #         -------- input1
    #   S    /
    #  ---> v
    #  <--output +
    #      ^
    #       \_________ input2
    # think in terms of sensitivity.
    # - output has sensitivity S (i.e. dt/doutput = S)
    # - output = input1 + input2
    # - how much sensitivity does input1 get?
    # - the same (S), because "sensitivity" is linear [a conjecture/axiom]
    # output = f(input1, input2); f(input1, input2) = input1 + input2
    def backprop(self, dt_output):
        # dt/dinput1 = dt/doutput * doutput/dinput1
        #            = dt/doutput * d(f(input1, input2))/dinput1
        #            = dt/doutput * d(input1 + input2)/dinput1
        #            = dt/doutput * 1
        self.lhs.backprop(dt_output * 1)
        self.rhs.backprop(dt_output * 1)
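
# The analogous spot-check for Add (example values assumed):
# d(a+b)/da = d(a+b)/db = 1, so both inputs receive the full seed.
_sa = Var("a", 3)
_sb = Var("b", 4)
Add(_sa, _sb).backprop(1)          # seed dt/doutput = 1, i.e. t = a+b
assert _sa.grad() == 1 and _sb.grad() == 1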

class Max(Expr):
    def __init__(self, lhs, rhs):
        self.lhs = lhs
        self.rhs = rhs
        self.val = max(self.lhs.val, self.rhs.val)
    def __str__(self):
        return "(max %s %s | %s)" % (self.lhs, self.rhs, self.val)
    def __repr__(self):
        return self.__str__()

    def backprop(self, dt_output):
        # dt/dinput1 = dt/doutput * doutput/dinput1
        #            = dt/doutput * d max(input1, input2)/dinput1
        #            = dt/doutput * 1  [if input1 >= input2: the max is input1]
        #            = dt/doutput * 0  [if input1 <  input2: input1 does not affect the max]
        if self.val == self.lhs.val:
            self.lhs.backprop(dt_output * 1)
        else:
            self.rhs.backprop(dt_output * 1)
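
# A small sketch of Max in use (example values assumed): only the larger
# input receives gradient, so Max acts like a ReLU-style gate.
_xa = Var("a", -3)
_xb = Var("b", 5)
Max(_xa, _xb).backprop(1)          # val = 5, so the gradient flows to b
assert _xa.grad() == 0 and _xb.grad() == 1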

x = Var("x", 10)
print("x: %s" % x)
y = Var("y", 20)
p = Var("p", 30)
print("y: %s" % y)
z0 = Mul(x, x)
print("z0: %s" % z0)
z1 = Add(z0, y)
print("z1: %s" % z1)

# z1 = x*x+y
# dz1/dx = 2x
# dz1/dy = 1
# dz1/dp = 0
# z1.clear_grad()
z1.backprop(1) #t = z1
print("dz/dx: %s" % x.grad())
print("dz/dy: %s" % y.grad())
print("dz/dp: %s" % p.grad())

x.clear_grad()
y.clear_grad()
z1.backprop(1) #t = z1
print("dz/dx: %s" % x.grad())
print("dz/dy: %s" % y.grad())