Wednesday, November 28, 2012

Fast Exponentiation

Computing the power of a number is quite common. A simple form that appears frequently in coding interviews is to compute powers of integers.
Problem: given two integers n and k, compute the number \( n^k \).
One well-known algorithm is to perform exponentiation by repeated squaring. I saw the following pseudo code in Udi Manber's book:
Input: n and k 
Output: P
P := n;
j := k;
while j > 1 do
  P := P*P;
  if j mod 2 = 1 then
    P := P*n;
  j = j div 2;
end
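To make the failure easy to trace, here is a direct Java transcription of that pseudocode (a hypothetical helper class of mine, not from the book), bugs included:

```java
public class ManberTrace {
    // Direct transcription of the pseudocode above, bug included.
    static long buggyPower(int n, int k) {
        long P = n;
        int j = k;
        while (j > 1) {
            P = P * P;
            if (j % 2 == 1) {
                P = P * n;
            }
            j = j / 2;
        }
        return P;
    }

    public static void main(String[] args) {
        // n = 2, k = 5: the trace is P = 2 -> 4 -> 8 (odd step) -> 64,
        // while the correct value of 2^5 is 32.
        System.out.println(buggyPower(2, 5)); // prints 64
    }
}
```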
The code is wrong, as can be seen by testing k = 5, n = 2. The above program produces 64, while 32 is the expected result. The simplest implementation of repeated squaring is recursive, which is hard to get wrong:
long power(int n, int k) {
  assert n > 0 && k > 0;
  if (k == 1) return n;

  if ((k & 1) == 1) { // odd exponent; note (k & 1) needs parentheses in Java
    return n * power(n, k - 1);
  }
  else {
    long half = power(n, k / 2);
    return half * half;
  }
}
The pseudo code we presented at the beginning is iterative, and hence more efficient. However, since it maintains no correct loop invariant, it fails to compute the desired value. A simple way to compute \(n^k\) iteratively is to scan the exponent k bit by bit from low to high: whenever bit i is 1, we multiply the result by the corresponding factor \(n^{2^i}\), which we obtain by repeatedly squaring a running base. There are about \(\log_2 k\) bits in total.
long power_iterative(int n, int k) {
  assert n > 0 && k > 0;
  int b = k; // bits of the exponent still to be processed
  long result = 1;
  long base = n;
  while (b != 0) {
    // invariant: result*(base^b) == n^k
    if ((b & 1) == 1) {
      result *= base;
    }
    base *= base;
    b >>= 1;
  }
  return result;
}
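As a quick sanity check, the iterative routine can be exercised against a few hand-computed powers. The driver below repeats the method inside a class of mine (`PowerCheck`) so that the snippet compiles on its own:

```java
public class PowerCheck {
    // Iterative repeated squaring: scan k's bits from low to high.
    static long powerIterative(int n, int k) {
        assert n > 0 && k > 0;
        int b = k;       // bits of the exponent still to be processed
        long result = 1;
        long base = n;   // base == n^(2^i) when processing bit i
        while (b != 0) {
            // invariant: result * base^b == n^k
            if ((b & 1) == 1) {
                result *= base;
            }
            base *= base;
            b >>= 1;
        }
        return result;
    }

    public static void main(String[] args) {
        System.out.println(powerIterative(2, 5));  // 32
        System.out.println(powerIterative(3, 4));  // 81
        System.out.println(powerIterative(10, 3)); // 1000
    }
}
```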
If the arguments are not guaranteed to be positive, we have to take extra care; note the return type becomes double so that negative exponents can be handled.
double power(int n, int k) {
  if (n == 0 && k <= 0) 
    throw new RuntimeException("For 0, only positive exponent allowed!");
  if (k == 0) return 1;
  if (n == 0) return 0; 

  int sign = 1;
  if (n < 0) {
    sign = (k & 1) == 1 ? -1 : 1; // negative base, odd exponent
    n = -n; // note: overflows for Integer.MIN_VALUE
  }

  long result = sign;
  result *= power_iterative(n, Math.abs(k));
  
  return k > 0 ? result : 1.0/result;
}
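A few usage examples for the sign-aware version. The class below (a name of mine, `SignedPower`) bundles it with the iterative helper so the snippet is self-contained:

```java
public class SignedPower {
    // Iterative repeated squaring for positive base and exponent.
    static long powerIterative(int n, int k) {
        assert n > 0 && k > 0;
        int b = k;
        long result = 1;
        long base = n;
        while (b != 0) {
            if ((b & 1) == 1) result *= base;
            base *= base;
            b >>= 1;
        }
        return result;
    }

    // Handles zero and negative bases/exponents as described in the post.
    static double power(int n, int k) {
        if (n == 0 && k <= 0)
            throw new RuntimeException("For 0, only positive exponent allowed!");
        if (k == 0) return 1;
        if (n == 0) return 0;

        int sign = 1;
        if (n < 0) {
            sign = (k & 1) == 1 ? -1 : 1; // negative base, odd exponent
            n = -n;
        }

        long result = sign;
        result *= powerIterative(n, Math.abs(k));

        return k > 0 ? result : 1.0 / result;
    }

    public static void main(String[] args) {
        System.out.println(power(2, -2)); // 0.25
        System.out.println(power(-2, 3)); // -8.0
        System.out.println(power(-2, 4)); // 16.0
    }
}
```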
One application of fast exponentiation is to compute Fibonacci number. We recall that the Fibonacci number can be expressed in a matrix form: $$ \begin{bmatrix} f(n+1) \\ f(n) \end{bmatrix}= \begin{bmatrix} 1 & 1 \\ 1 & 0 \end{bmatrix}\times \begin{bmatrix} f(n) \\ f(n-1) \end{bmatrix} $$ If we denote by A the matrix: \( \begin{bmatrix} 1 & 1 \\ 1 & 0 \end{bmatrix} \), then we know $$\begin{bmatrix} f(n+1) \\ f(n) \end{bmatrix} = A^{n}\times\begin{bmatrix} f(1) \\ f(0) \end{bmatrix}$$
The matrix-power algorithm runs in O(log n) matrix multiplications, while the straightforward iterative accumulation takes O(n) additions.
class Matrix{
  long a, b, c, d;
  Matrix(long i, long j, long k, long l) { a = i; b = j; c = k; d = l; }
  Matrix(Matrix m) {a = m.a; b = m.b; c = m.c; d = m.d;}
}
Matrix multiply(Matrix n, Matrix m) {
  return new Matrix(n.a*m.a + n.b*m.c, n.a*m.b + n.b*m.d,
                    n.c*m.a + n.d*m.c, n.c*m.b + n.d*m.d);
}
Matrix fastExponentiation(Matrix m, int k) {
  assert k > 0;
  Matrix base = new Matrix(m);
  Matrix result = new Matrix(1, 0, 0, 1); // identity matrix
  int b = k;
  while (b != 0) {
    if ((b & 1) == 1) result = multiply(result, base);
    base = multiply(base, base);
    b >>= 1;
  }
  return result;
}
long fibonacci(int n) {
  assert n >= 0;
  if (n == 0 || n == 1) return 1;
  Matrix m = new Matrix(1, 1, 1, 0);
  Matrix power = fastExponentiation(m, n);
  // [f(n+1), f(n)]^T = A^n * [f(1), f(0)]^T = A^n * [1, 1]^T,
  // so f(n) is the second row of A^n times [1, 1]^T.
  return power.c + power.d;
}
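Putting the pieces together (with the missing `return` restored in `fastExponentiation` and the matrix entries combined per the identity above), here is a self-contained check of the matrix-based Fibonacci under this post's convention f(0) = f(1) = 1; the wrapper class name `MatrixFib` is mine:

```java
public class MatrixFib {
    // 2x2 matrix with entries [[a, b], [c, d]].
    static class Matrix {
        long a, b, c, d;
        Matrix(long i, long j, long k, long l) { a = i; b = j; c = k; d = l; }
        Matrix(Matrix m) { a = m.a; b = m.b; c = m.c; d = m.d; }
    }

    static Matrix multiply(Matrix n, Matrix m) {
        return new Matrix(n.a*m.a + n.b*m.c, n.a*m.b + n.b*m.d,
                          n.c*m.a + n.d*m.c, n.c*m.b + n.d*m.d);
    }

    // Repeated squaring on matrices, same bit-scanning loop as before.
    static Matrix fastExponentiation(Matrix m, int k) {
        assert k > 0;
        Matrix base = new Matrix(m);
        Matrix result = new Matrix(1, 0, 0, 1); // identity matrix
        int b = k;
        while (b != 0) {
            if ((b & 1) == 1) result = multiply(result, base);
            base = multiply(base, base);
            b >>= 1;
        }
        return result;
    }

    static long fibonacci(int n) {
        assert n >= 0;
        if (n == 0 || n == 1) return 1;
        Matrix power = fastExponentiation(new Matrix(1, 1, 1, 0), n);
        // [f(n+1), f(n)]^T = A^n * [1, 1]^T, so f(n) = c + d.
        return power.c + power.d;
    }

    public static void main(String[] args) {
        for (int i = 0; i <= 7; i++) {
            System.out.print(fibonacci(i) + " "); // 1 1 2 3 5 8 13 21
        }
        System.out.println();
    }
}
```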
