Since my project is winding down fast, I thought I'd just expand more on the math behind the scenes.
The conjugate gradient method is an iterative algorithm that solves equations of the form Ax = b, and it can be stopped partway through to get a good estimate for x. It only works if A is symmetric and positive-definite. It is only practical when A is sparse, so that multiplications by A are cheap to compute.
Let's start with a quick definition:
Def: u and v are conjugate w.r.t. A if u^T A v = 0. Basically, u and Av are orthogonal.
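To make the definition concrete, here's a tiny numpy sketch with numbers I made up for illustration: u and v below are conjugate w.r.t. a symmetric positive-definite A even though they are not orthogonal in the ordinary sense.

import numpy as np

# A made-up symmetric positive-definite matrix.
A = np.array([[4.0, 1.0],
              [1.0, 3.0]])

u = np.array([1.0, 0.0])
v = np.array([1.0, -4.0])  # chosen so that u^T A v = 0

print("u^T A v =", u @ A @ v)  # ~0, so u and v are conjugate w.r.t. A
print("u^T v   =", u @ v)      # nonzero, so they are not plain orthogonal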
The problem is solving Ax = b. x can be broken down into a linear combination of a basis of R^n called {d_k}, or 'directions', that are all conjugate w.r.t. A. Whenever we talk about linear combinations, I like to think of each basis vector as a direction. Bear with me for a second while I gloss over why it is cool if they're conjugate. Now take the equations:
x = ∑_i α_i d_i            (by the linear combination definition)
b = Ax = ∑_i α_i A d_i     (by the problem definition)
This isn't terribly interesting; it just rewrites Ax as A times a completely arbitrary decomposition of x into {d_k}. But imagine we knew all the directions in {d_k} and just wanted to solve for the weights α. Normally this wouldn't tell us anything: I can break x in R^3 into α_1[1; 0; 0] + α_2[0; 1; 0] + α_3[0; 0; 1], and I still don't have a good way to find the α's and x. But if my basis vectors are conjugate, I can use the following derivation:
b = Ax = ∑_i α_i A d_i                  (see above)
d_k^T b = ∑_i α_i d_k^T A d_i           (multiply both sides by d_k^T)
d_k^T b = α_k d_k^T A d_k               (by the fact that d_i^T A d_k = 0 for i ≠ k)
α_k = (d_k^T b) / (d_k^T A d_k)
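Here's a small numpy sketch of that derivation, using a made-up symmetric positive-definite system: it builds a conjugate basis (by Gram-Schmidt in the A-inner-product, an idea the post comes back to below), computes each α_k = (d_k^T b)/(d_k^T A d_k) independently, and sums the weighted directions to recover x.

import numpy as np

# Made-up symmetric positive-definite system, just for illustration.
A = np.array([[4.0, 1.0, 0.0],
              [1.0, 3.0, 1.0],
              [0.0, 1.0, 2.0]])
b = np.array([1.0, 2.0, 3.0])

# Build a conjugate basis by A-orthogonalizing the standard basis
# (Gram-Schmidt with the inner product <u, v> = u^T A v).
dirs = []
for e in np.eye(3):
    d = e.copy()
    for dj in dirs:
        d -= (dj @ A @ e) / (dj @ A @ dj) * dj
    dirs.append(d)

# With conjugate directions the weights decouple:
# alpha_k = (d_k^T b) / (d_k^T A d_k).
x = np.zeros(3)
for d in dirs:
    x += (d @ b) / (d @ A @ d) * d

print(x)                      # reconstructed solution
print(np.linalg.solve(A, b))  # matches the direct solve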
So now we have an easy-to-solve equation for all the weights that make Ax = b true. The direct algorithm breaks into two steps: find conjugate directions w.r.t. A, then efficiently solve for the α's and add the weighted directions together to get x. In practice, these steps can be intermixed into an iterative algorithm that updates x after every iteration.
Conjugate Gradient Algorithm:
for k = 1:n
    Find direction d_k
    Find α_k
    x = x + α_k d_k
But how do we find d_k? One can imagine many different ways to compute a set of conjugate directions, but some combine to form x in fewer iterations than others. This is where different methods such as MINRES and GMRES appear, along with the one I'm concerned with: Conjugate Gradient. In fact, the manner of selecting d_k is where the conjugate gradient method gets its name!

Imagine we have an energy function E(x) = ½ x^T A x - x^T b + c. Since A is symmetric, the gradient (derivative) is E'(x) = Ax - b, and E(x) has its minimum where 0 = Ax - b, i.e. where Ax = b (basic calculus; positive-definiteness guarantees this is a minimum). If you were writing an iterative algorithm that finds the minimum of E(x) without solving directly for x, you might just greedily walk downhill along the negative gradient, which is b - Ax where x is your current guess. Take note that b - Ax is both the (negative) gradient and the residual in this special case, so the two terms get mixed together frequently. If we just followed this greedy rule, we'd be doing a steepest descent algorithm. Unfortunately, the directions it produces aren't orthogonal or conjugate to each other, and this leads to inefficiency.
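If you want to convince yourself of that gradient formula, here's a quick finite-difference check (the particular A, b, c, and x are made up; the only requirement is that A is symmetric):

import numpy as np

# Made-up symmetric A and arbitrary b, c, x for the check.
A = np.array([[4.0, 1.0],
              [1.0, 3.0]])
b = np.array([1.0, 2.0])
c = 5.0

def E(x):
    return 0.5 * x @ A @ x - x @ b + c

x = np.array([0.3, -0.7])
eps = 1e-6

# Central-difference approximation of the gradient of E at x.
num_grad = np.array([(E(x + eps * e) - E(x - eps * e)) / (2 * eps)
                     for e in np.eye(2)])

print(num_grad)   # numerically matches...
print(A @ x - b)  # ...the closed-form gradient Ax - b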
We'll borrow the idea behind steepest descent to pick our first direction: just say x = 0, plug it into b - Ax, and get b as our first direction (the residual, i.e. the negative gradient, at x = 0). The rest of the directions are formed by making each new residual conjugate to the previous directions with the following rules:
d_0 = b
d_k = r_{k-1} - ∑_{i<k} ((d_i^T A r_{k-1}) / (d_i^T A d_i)) d_i
Basically, d_k is the previous residual (negative gradient) orthogonalized (or rather, 'conjugatilized', because we multiply by A) against all d_i for i < k. This should feel a lot like Gram-Schmidt orthogonalization. Here's an updated outline of the algorithm:
Conjugate Gradient Algorithm:
for k = 1:n
    Find direction d_k
    for i = 1:k-1
        'conjugatilize' d_k with d_i
    Find α_k
    x = x + α_k d_k
So as a recap, we now have a good heuristic way of picking directions (d_k) such that we can exploit their conjugate property and solve for their weights (α_k).
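As a sanity check on the outline above, here's a naive numpy sketch that keeps every previous direction around and 'conjugatilizes' each new residual against all of them, exactly as written (the test system is made up; a real implementation avoids the stored history, as discussed next).

import numpy as np

def conjugate_gradient_naive(A, b):
    """Naive version of the outline: store every previous direction and
    explicitly 'conjugatilize' each new residual against all of them."""
    n = len(b)
    x = np.zeros(n)
    dirs = []
    for _ in range(n):
        r = b - A @ x                  # residual = negative gradient at x
        if np.linalg.norm(r) < 1e-12:
            break                      # already converged
        d = r.copy()
        for dj in dirs:                # d_k = r - sum_i (d_i^T A r)/(d_i^T A d_i) d_i
            d -= (dj @ A @ r) / (dj @ A @ dj) * dj
        alpha = (d @ b) / (d @ A @ d)  # alpha_k = (d_k^T b)/(d_k^T A d_k)
        x += alpha * d
        dirs.append(d)
    return x

# Made-up symmetric positive-definite test system.
A = np.array([[4.0, 1.0, 0.0],
              [1.0, 3.0, 1.0],
              [0.0, 1.0, 2.0]])
b = np.array([1.0, 2.0, 3.0])
print(conjugate_gradient_naive(A, b))
print(np.linalg.solve(A, b))           # agrees with the direct solve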
The algorithm listed above captures all the intuition of conjugate gradient, but doesn't look anything like what is actually implemented. The reason is that the inner for-loop can be optimized away, along with the requirement to store all previous bases. I might expand on this idea in a later post.
This is the more common form:
Conjugate Gradient Algorithm:
x_0 = 0, r_0 = b, d_0 = r_0
for k = 0:n-1
    α_k = (r_k^T r_k) / (d_k^T A d_k)
    x_{k+1} = x_k + α_k d_k
    r_{k+1} = r_k - α_k A d_k
    β_k = (r_{k+1}^T r_{k+1}) / (r_k^T r_k)
    d_{k+1} = r_{k+1} + β_k d_k
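And here is a minimal numpy sketch of that common form, under the same assumptions as before (symmetric positive-definite A, made-up test data); note that it only carries the current residual and direction, with no stored history.

import numpy as np

def conjugate_gradient(A, b, tol=1e-10):
    """Standard-form CG: only the current residual and direction are kept."""
    n = len(b)
    x = np.zeros(n)
    r = b - A @ x        # r_0
    d = r.copy()         # d_0
    rs_old = r @ r
    for _ in range(n):
        if np.sqrt(rs_old) < tol:
            break
        Ad = A @ d
        alpha = rs_old / (d @ Ad)  # alpha_k = (r_k^T r_k)/(d_k^T A d_k)
        x += alpha * d             # x_{k+1} = x_k + alpha_k d_k
        r -= alpha * Ad            # r_{k+1} = r_k - alpha_k A d_k
        rs_new = r @ r
        beta = rs_new / rs_old     # beta_k = (r_{k+1}^T r_{k+1})/(r_k^T r_k)
        d = r + beta * d           # d_{k+1} = r_{k+1} + beta_k d_k
        rs_old = rs_new
    return x

# Same made-up test system as before.
A = np.array([[4.0, 1.0, 0.0],
              [1.0, 3.0, 1.0],
              [0.0, 1.0, 2.0]])
b = np.array([1.0, 2.0, 3.0])
print(conjugate_gradient(A, b))
print(np.linalg.solve(A, b))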
Gradient Descent
Given an energy function of the form E(x) = ½ x^T A x - x^T b + c, our task is to minimize E(x) without directly solving for x, because that would be too expensive.
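Here's a minimal steepest-descent sketch under the same assumptions as the earlier examples (made-up symmetric positive-definite A and b): each iteration steps along the negative gradient r = b - Ax, using the step length α = (r^T r)/(r^T A r) that exactly minimizes E along that direction.

import numpy as np

def steepest_descent(A, b, tol=1e-10, max_iter=10000):
    """Greedy minimization of E(x) = 0.5 x^T A x - x^T b + c by always
    stepping along the negative gradient (the residual)."""
    x = np.zeros(len(b))
    for _ in range(max_iter):
        r = b - A @ x                   # negative gradient at the current x
        if np.linalg.norm(r) < tol:
            break
        alpha = (r @ r) / (r @ A @ r)   # exact line search along r
        x += alpha * r
    return x

# Same made-up test system; typically takes more iterations than CG.
A = np.array([[4.0, 1.0, 0.0],
              [1.0, 3.0, 1.0],
              [0.0, 1.0, 2.0]])
b = np.array([1.0, 2.0, 3.0])
print(steepest_descent(A, b))
print(np.linalg.solve(A, b))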