package org.sc3d.apt.image.v24;

import java.util.Random;

/** An algorithm for minimising an objective function. This is a direct search
 * algorithm, in the sense of "Optimization by Direct Search: New Perspectives
 * on Some Classical and Modern Methods" by Tamara G. Kolda, Robert Michael
 * Lewis and Virginia Torczon ( http://www.cs.wm.edu/~va/research/sirev.pdf ),
 * and can optimise a function without knowing the derivative of that function.
 * It can even optimise functions whose derivatives are not defined at the
 * minimum, or near the path of steepest descent. Unlike most such methods, the
 * algorithm is spherically symmetric, and uses an adaptive metric. This means
 * it shouldn't get stuck in narrow ravines even when they are not aligned with
 * the coordinate axes. It is a stochastic algorithm, in the sense that search
 * directions are chosen at random. */
public class Optimiser {
  /** Constructs an Optimiser for the specified problem.
   * @param f the function to minimise.
   * @param x an initial guess as to where the minimum might be. The array is
   * stored (not copied) as the initial value of the field 'x'.
   * @param step the distance by which it would be sensible to change 'x' on the
   * first step of the algorithm. The initial step ellipsoid is a sphere of
   * radius 'Math.abs(step)'.
   * @throws IllegalArgumentException if 'x.length' differs from 'f.numVars'.
   */
  public Optimiser(ObjectiveFunction f, double[] x, double step) {
    this.f = f;
    this.numVars = f.numVars;
    // Validate before evaluating 'f', so that a mismatched guess is reported
    // with a clear message rather than surfacing as an arbitrary failure from
    // inside 'f.applyTo()'.
    if (x.length!=this.numVars) throw new IllegalArgumentException(
      "Initial guess has the wrong number of dimensions"
    );
    this.x = x;
    this.fx = f.applyTo(x);
    // Start with 'step' times the identity matrix.
    this.m = new double[numVars*numVars];
    for (int i=0; i<numVars; i++) this.m[i+numVars*i] = step;
  }
  
  /* New API. */
  
  /** The function that this Optimiser tries to minimise. */
  public final ObjectiveFunction f;
  
  /** The number of input variables to 'f'. This is equal to 'f.numVars'. */
  public final int numVars;
  
  /** The best guess so far as to where the minimum might be. */
  public double[] x;
  
  /** The value of 'f' at 'x'. */
  public double fx;
  
  /** A matrix defining the size of the next step that the algorithm will try. A
   * matrix is needed because the desired step size may be very different in
   * different directions. The locus of possible next steps is an ellipsoid
   * centred on 'x'. The ellipsoid is the same shape and size as the image of a
   * unit sphere under this matrix.
   * <p>
   * This matrix has 'numVars*numVars' elements, but only
   * 'numVars*(numVars+1)/2' degrees of freedom are needed to define an
   * ellipsoid. The representation is therefore redundant. Specifically,
   * postmultiplying this matrix by an orthogonal matrix (which has
   * 'numVars*(numVars-1)/2' degrees of freedom) does not change the ellipsoid,
   * because an orthogonal matrix maps the unit sphere onto itself.
   * <p>
   * The elements of the matrix are listed in left-to-right, top-to-bottom
   * order, i.e. the element in row 'i' and column 'j' is 'm[j+numVars*i]'.
   */
  public double[] m;
  
  /** Attempts to improve 'x'. This takes time O('numVars*numVars') and
   * evaluates 'f' up to twice.
   * <p>
   * This method picks a unit vector 'u' at random, and multiplies it by 'm' to
   * form a step vector 'd'. The function 'f' is then evaluated at 'x-d' and,
   * only if that is not an improvement, at 'x+d'. The behaviour then depends on
   * whether the results are better (less) than 'fx':
   * <ul>
   * <li>If either point gives an improvement on 'fx', 'x' is replaced with the
   * first such point found (and 'fx' is changed accordingly), 'm' is modified
   * so that subsequent steps in the same direction will be twice as long, and
   * this method returns 'true'.
   * <li>If neither point results in an improvement, then 'x' is not modified,
   * 'm' is modified so that subsequent steps in the same direction will be
   * half as long, and this method returns 'false'.
   * </ul>
   * In both cases 'm' is unchanged for directions orthogonal to 'u'.
   * @return 'true' if 'x' changed and 'f(x)' decreased, or 'false' if the step
   * considered was unsuccessful.
   */
  public boolean tryStep() {
    final int n = this.numVars;
    // Pick a unit vector, in a spherically symmetric way: a vector of
    // independent standard Gaussians has a rotationally invariant
    // distribution, so normalising it yields a uniformly random direction.
    final double[] u = new double[n];
    double u2 = 0.0;
    for (int i=0; i<n; i++) {
      u[i] = RANDOM.nextGaussian();
      u2 += u[i]*u[i];
    }
    final double invSqrtU2 = 1.0 / Math.sqrt(u2);
    for (int i=0; i<n; i++) u[i] *= invSqrtU2;
    // Form the step vector 'd = m*u'.
    final double[] d = new double[n];
    for (int i=0; i<n; i++) for (int j=0; j<n; j++) {
      d[i] += u[j] * this.m[j+n*i];
    }
    // Evaluate 'f' at 'x-d', then (if needed) at 'x+d'; accept the first
    // strict improvement.
    boolean ans = false;
    for (int dir=-1; dir<=1; dir+=2) {
      final double[] newX = new double[n];
      for (int i=0; i<n; i++) newX[i] = this.x[i] + dir*d[i];
      final double newFX = this.f.applyTo(newX);
      if (newFX < this.fx) {
        this.x = newX;
        this.fx = newFX;
        ans = true;
        break;
      }
    }
    // Apply the rank-one update 'm += stretch * d * transpose(u)'. Because 'u'
    // is a unit vector, the updated 'm' maps 'u' to '(1+stretch)*d' (twice as
    // long after a success, half as long after a failure), while any vector
    // orthogonal to 'u' is mapped exactly as before.
    final double stretch = ans ? 1.0 : -0.5;
    for (int i=0; i<n; i++) for (int j=0; j<n; j++) {
      this.m[j+n*i] += stretch * d[i] * u[j];
    }
    return ans;
  }
  
  /** Calls 'tryStep()' repeatedly until a step is found that decreases 'fx', or
   * until even very small steps (of size 'minStep') don't give any improvement.
   * More precisely, this method gives up when the sum of the squares of the
   * lengths of the semi-axes of the step ellipsoid (the trace of
   * 'm*transpose(m)', i.e. the sum of the squares of all elements of 'm') is
   * smaller than 'minStep*minStep'.
   * <p>
   * Hopefully the step size will already be appropriate, and a step that
   * decreases 'fx' will be found in the first few trials.
   * @param minStep the step size below which the search is abandoned. If this
   * is zero or NaN the threshold can never be reached, so this method only
   * returns once a step succeeds.
   * @return 'true' if a step is found that successfully decreases 'fx', or
   * 'false' if even very small steps fail.
   */
  public boolean improve(double minStep) {
    final double threshold = minStep*minStep;
    do {
      double norm = 0.0;
      for (int i=0; i<this.m.length; i++) norm += this.m[i]*this.m[i];
      if (norm<threshold) return false;
    } while (!this.tryStep());
    return true;
  }
  
  /** Repeatedly calls 'improve(minStep)' until it returns 'false', i.e. until
   * even very small steps do not improve 'fx'.
   * @param minStep passed to 'improve()'; see there. It should be non-zero,
   * otherwise this method may never return.
   * @return the final value of 'x'.
   */
  public double[] minimise(double minStep) {
    while (this.improve(minStep)) {}
    return this.x;
  }
  
  ////////////////////////////////////////////////////////////////////////////

  /** The abstract superclass of functions suitable for minimisation. */
  public static abstract class ObjectiveFunction {
    /** Constructs an ObjectiveFunction, given the value of 'numVars'. */
    public ObjectiveFunction(int numVars) { this.numVars = numVars; }
    
    /** The number of inputs to this function. */
    public final int numVars;
    
    /** Evaluates this function for the specified inputs.
     * @param x an array of length 'numVars'.
     * @return the value of this function at 'x'.
     */
    public abstract double applyTo(double[] x);
  }
  
  ////////////////////////////////////////////////////////////////////////////
  
  /* Private. */
  
  /** A source of random numbers, shared by all Optimisers. */
  private static final Random RANDOM = new Random();
  
  /* Test code. */
  
  /** Interactive test. It loops forever: it renders a sample 2-variable
   * function over the square [0,3)x[0,3) as a 'w' by 'h' grey map, overlays
   * the current step ellipse as 20 alternating black and white dots, displays
   * the picture via the project's GreyMap class, and then takes one step.
   */
  public static void main(String[] args) throws java.io.IOException {
    final ObjectiveFunction f = new ObjectiveFunction(2) {
      @Override
      public double applyTo(double[] x) {
        return Math.abs(Math.cos(x[0]) + 0.5*Math.cos(x[1])) - Math.sin(x[1]);
      }
    };
    final Optimiser me = new Optimiser(f, new double[] {1.0, 0.0}, 0.1);
    final int w = 512, h = 512;
    while (true) {
      final byte[] ps = new byte[w*h];
      for (int y=0; y<h; y++) for (int x=0; x<w; x++) {
        final double[] p = new double[] {x*3.0/w, y*3.0/h};
        ps[x+w*y] = (byte)(100 + 100*f.applyTo(p));
      }
      // Plot 'x + m*(cos a, sin a)' for 20 angles 'a', converting from
      // function coordinates back to pixels with the same scaling as above.
      for (int i=0; i<20; i++) {
        final double a = i*2*Math.PI/20;
        final double c = Math.cos(a), s = Math.sin(a);
        final int x = (int)Math.round(w*(me.x[0] + me.m[0]*c + me.m[1]*s)/3.0);
        final int y = (int)Math.round(h*(me.x[1] + me.m[2]*c + me.m[3]*s)/3.0);
        if (x>=0 && x<w && y>=0 && y<h) ps[x+w*y] = (byte)((i&1)==0 ? 0 : 255);
      }
      new org.sc3d.apt.image.GreyMap(w, h, ps).display("Testing Optimiser");
      me.tryStep();
    }
  }
}
