Optimisation of recursive algorithm in Java

Background

I have an ordered set of data points stored as a `TreeSet<DataPoint>`. Each data point has a `position` and a `Set` of `Event` objects (`HashSet<Event>`).

There are 4 possible `Event` objects: `A`, `B`, `C` and `D`. Every `DataPoint` has 2 of these, e.g. `A` and `C`, except the first and last `DataPoint` objects in the set, whose event set `T` has size 1.
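The question doesn't show `DataPoint` or `Event`, so here is a hypothetical minimal shape (all names and fields are assumptions) to make the data model concrete. The main point is that a `TreeSet` needs an ordering, so this `DataPoint` compares by `position`. It uses an `EnumSet` for the events; a `HashSet<Event>` as in the question behaves the same way. Records need Java 16 or later.

```java
import java.util.EnumSet;
import java.util.NavigableSet;
import java.util.Set;
import java.util.TreeSet;

public class DataModelSketch {
    enum Event { A, B, C, D }

    // Hypothetical DataPoint: a TreeSet needs an ordering, so points compare by
    // position. Two points at the same position count as the same element.
    record DataPoint(double position, Set<Event> events) implements Comparable<DataPoint> {
        @Override
        public int compareTo(DataPoint other) {
            return Double.compare(position, other.position);
        }
    }

    static NavigableSet<DataPoint> examplePoints() {
        NavigableSet<DataPoint> points = new TreeSet<>();
        points.add(new DataPoint(0.0, EnumSet.of(Event.B)));          // left-hand node: one event
        points.add(new DataPoint(1.0, EnumSet.of(Event.A, Event.C))); // interior: two events
        points.add(new DataPoint(2.0, EnumSet.of(Event.B, Event.D))); // interior: two events
        points.add(new DataPoint(3.0, EnumSet.of(Event.D)));          // right-hand node: one event
        return points;
    }

    public static void main(String[] args) {
        NavigableSet<DataPoint> points = examplePoints();
        DataPoint right = points.last();
        // lower() returns the neighbour to the left, which is what the recursion walks through.
        System.out.println(points.lower(right).position()); // prints 2.0
    }
}
```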

My algorithm finds the probability that a new `DataPoint` `Q` at position `x` has `Event` `q` when it is added to this set.

I do this by calculating a value `S` for this data set, then adding `Q` to the set and calculating `S` again. I then divide the second `S` by the first to isolate the probability for the new `DataPoint` `Q`.

Algorithm

The formula for calculating `S` is:

http://mathbin.net/equations/105225_0.png

where

http://mathbin.net/equations/105225_1.png

http://mathbin.net/equations/105225_2.png

for

http://mathbin.net/equations/105225_3.png

and

http://mathbin.net/equations/105225_4.png

http://mathbin.net/equations/105225_5.png is an expensive probability function that depends only on its arguments (and http://mathbin.net/equations/105225_6.png); http://mathbin.net/equations/105225_7.png is the last `DataPoint` in the set (the right-hand node); http://mathbin.net/equations/105225_8.png is the first `DataPoint` (the left-hand node); http://mathbin.net/equations/105225_9.png is the rightmost `DataPoint` that isn't the right-hand node; http://mathbin.net/equations/105225_10.png is a `DataPoint`; and http://mathbin.net/equations/105225_12.png is the `Set` of events for this `DataPoint`.
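The equation images may not load. Reading the recursion back from the Java implementation in this question (an interpretation in my own notation, not a copy of the images: $D_0$ is the left-hand node, $D_{n+1}$ the right-hand node with event $r$, $T_i$ the event set of $D_i$, and $T_Q = \{q\}$ once $Q$ is inserted), it appears to be:

```latex
S = f(D_{n+1}, r), \qquad
f(D_i, e) =
\begin{cases}
  \tfrac{1}{4}\, p(D_i, e, D_0, \varnothing) & \text{if } i = 1,\\[4pt]
  \displaystyle\sum_{e' \in T_{i-1}} p(D_i, e, D_{i-1}, e')\, f(D_{i-1}, e') & \text{if } i > 1.
\end{cases}
```

Because each interior $T_{i-1}$ has two events, every level of the recursion refers to $f$ twice.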

So the probability for `Q` with `Event` `q` is:

http://mathbin.net/equations/105225_11.png
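Going by the two evaluations of `S` described in the background (again my own notation, taken from the text rather than the image), the probability being computed is the ratio:

```latex
\Pr(Q \text{ at } x \text{ has event } q) = \frac{S(\text{points} \cup \{Q\})}{S(\text{points})}
```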

Implementation

I implemented this algorithm in Java like so:

```java
import java.util.NavigableSet;

public class ProbabilityCalculator {

    private Double p(DataPoint right, Event rightEvent, DataPoint left, Event leftEvent) {
        // do some stuff and return the result
    }

    private Double f(DataPoint right, Event rightEvent, NavigableSet<DataPoint> points) {
        DataPoint left = points.lower(right);
        Double result = 0.0;
        if (left.isLefthandNode()) {
            result = 0.25 * p(right, rightEvent, left, null);
        } else if (left.isQ()) {
            result = p(right, rightEvent, left, left.getQEvent()) * f(left, left.getQEvent(), points);
        } else { // if M_k
            for (Event leftEvent : left.getEvents()) {
                result += p(right, rightEvent, left, leftEvent) * f(left, leftEvent, points);
            }
        }
        return result;
    }

    public Double S(NavigableSet<DataPoint> points) {
        return f(points.last(), points.last().getRightNodeEvent(), points);
    }
}
```

So to find the probability of `Q` at `x` with `q`:

```java
Double S1 = S(points);
points.add(Q);
Double S2 = S(points);
Double probability = S2 / S1;
```

Problem

As the implementation stands, it follows the mathematical algorithm closely. However, this turns out not to be a particularly good idea in practice, as `f` calls itself twice for each `DataPoint` (once per event). So for http://mathbin.net/equations/105225_9.png, `f` is called twice; then for point `n-1`, `f` is called twice again for each of those calls, and so on. This leads to a complexity of `O(2^n)`, which is pretty terrible considering there can be over 1000 `DataPoint`s in each `Set`. Because `p()` is independent of everything except its parameters, I have included a cache: if `p()` has already been calculated for these parameters, it just returns the previous result. This doesn't solve the inherent complexity problem, though. Am I missing something here with regard to repeated computations, or is the complexity unavoidable in this algorithm?
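To see the growth concretely, here is a small self-contained sketch that transcribes the recursion without caching `f` and counts the calls. Assumptions: integer indices stand in for `DataPoint`s (0 is the left-hand node, `m + 1` the right-hand node), every interior point has two events, and `p` is a hypothetical cheap formula rather than the real one.

```java
import java.util.List;

public class NaiveRecursionDemo {
    enum Event { A, B, C, D }

    static long fCalls = 0;

    // Hypothetical data: index 0 is the left-hand node, 1..m are interior
    // points with two events each, and m + 1 is the right-hand node.
    static List<Event> eventsOf(int i) {
        return i % 2 == 0 ? List.of(Event.A, Event.C) : List.of(Event.B, Event.D);
    }

    // Hypothetical cheap stand-in for the expensive p().
    static double p(int right, Event rightEvent, int left, Event leftEvent) {
        int le = leftEvent == null ? 0 : leftEvent.ordinal() + 1;
        return 1.0 / (2 + right + left + rightEvent.ordinal() + le);
    }

    // Direct transcription of the recursion, with no caching of f.
    static double f(int right, Event rightEvent) {
        fCalls++;
        int left = right - 1;
        if (left == 0) {
            return 0.25 * p(right, rightEvent, left, null);
        }
        double result = 0.0;
        for (Event leftEvent : eventsOf(left)) {
            result += p(right, rightEvent, left, leftEvent) * f(left, leftEvent);
        }
        return result;
    }

    // Number of calls to f needed to evaluate S with m interior points.
    static long callsForS(int m) {
        fCalls = 0;
        f(m + 1, Event.A);
        return fCalls;
    }

    public static void main(String[] args) {
        for (int m = 5; m <= 20; m += 5) {
            System.out.println(m + " interior points: " + callsForS(m) + " calls to f");
        }
    }
}
```

With `m` interior points it makes `2^(m+1) - 1` calls, so it prints 63, 2047, 65535 and 2097151 for `m` = 5, 10, 15 and 20. Caching `p` alone would not change these counts, because it is the calls to `f` that multiply.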

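You also need to memoize `f` on its first 2 arguments (the 3rd is always passed through unchanged, so you don't need to worry about it). This reduces the number of evaluations of `f` from `O(2^n)` to `O(n)`: there are at most two (`DataPoint`, `Event`) pairs per point, and each pair only has to be computed once.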
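A minimal sketch of that memoization, under the same kind of assumptions: integer positions in a `TreeSet` stand in for `DataPoint`s, `p` is a hypothetical cheap formula, and a record is the cache key (Java 16+). The cache is read with `get` and filled with `put` rather than `computeIfAbsent`, because `HashMap.computeIfAbsent` may throw `ConcurrentModificationException` when the mapping function itself modifies the map, as a recursive `f` would. The sketch also shows one caveat: the memo is only valid for the point set it was built on, so it has to be cleared before `S` is computed again after `Q` is added.

```java
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.NavigableSet;
import java.util.TreeSet;

public class MemoizedFDemo {
    enum Event { A, B, C, D }

    // Memo key: the first two arguments of f. Integer positions stand in for
    // DataPoints (hypothetical simplification).
    record FKey(int position, Event event) {}

    final NavigableSet<Integer> points = new TreeSet<>();
    final Map<FKey, Double> memo = new HashMap<>();
    long evaluations = 0;
    Integer qPosition = null; // hypothetical: position of the inserted Q, if any
    Event qEvent = null;

    // Interior points carry two events; Q carries only its own event q.
    List<Event> eventsOf(int position) {
        if (qPosition != null && qPosition == position) return List.of(qEvent);
        return position % 2 == 0 ? List.of(Event.A, Event.C) : List.of(Event.B, Event.D);
    }

    // Hypothetical cheap stand-in for the expensive p().
    static double p(int right, Event rightEvent, int left, Event leftEvent) {
        int le = leftEvent == null ? 0 : leftEvent.ordinal() + 1;
        return 1.0 / (2 + right + left + rightEvent.ordinal() + le);
    }

    double f(int right, Event rightEvent) {
        FKey key = new FKey(right, rightEvent);
        Double cached = memo.get(key); // query the memo first...
        if (cached != null) return cached;
        evaluations++;

        int left = points.lower(right);
        double result = 0.0;
        if (left == points.first()) { // left-hand node
            result = 0.25 * p(right, rightEvent, left, null);
        } else {
            for (Event leftEvent : eventsOf(left)) {
                result += p(right, rightEvent, left, leftEvent) * f(left, leftEvent);
            }
        }
        memo.put(key, result); // ...then store the new result
        return result;
    }

    // clearMemo = false deliberately reuses results computed for an older point set.
    double s(boolean clearMemo) {
        if (clearMemo) memo.clear();
        evaluations = 0;
        return f(points.last(), Event.A);
    }

    public static void main(String[] args) {
        MemoizedFDemo demo = new MemoizedFDemo();
        for (int pos = 0; pos <= 100; pos += 10) demo.points.add(pos);

        double s1 = demo.s(true);
        System.out.println("evaluations of f: " + demo.evaluations);

        demo.qPosition = 45;
        demo.qEvent = Event.B;
        demo.points.add(45);
        System.out.println("stale ratio: " + demo.s(false) / s1); // reuses the cached S1
        System.out.println("fresh ratio: " + demo.s(true) / s1);
    }
}
```

For these 11 points it evaluates `f` 19 times (2 per interior point plus 1), and with 1001 points it takes 1999 evaluations, i.e. linear growth. The stale call returns the old `S` unchanged, so its ratio is exactly 1.0; only the call that clears the memo gives a meaningful ratio.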

UPDATED:

Since, as commented below, the ordering cannot be used to help optimize, another method must be used. Most of the `p` values will be calculated multiple times (and, as noted, this is expensive), so one optimization would be to cache them. I am not sure what the best key would be, but you could imagine changing the code to something like:

```java
....
private Map<String, Double> previousResultMap = new ....

private Double p(DataPoint right, Event rightEvent, DataPoint left, Event leftEvent) {
    String key = // calculate unique key from inputs
    Double previousResult = previousResultMap.get(key);
    if (previousResult != null) {
        return previousResult;
    }

    // do some stuff
    previousResultMap.put(key, result);
    return result;
}
```

This approach should eliminate a lot of the redundant calculations. However, as you know the data much better than I do, you will need to decide how best to build the key (and whether a `String` is even the best representation for it).
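As a sketch of a key that avoids building `String`s (assumptions: a stand-in `DataPoint` record, a placeholder `expensiveP`, and Java 16+ for records), a record holding all four arguments gets `equals` and `hashCode` generated for it, and those handle the `null` left event of the left-hand node. Because computing `p` never touches the map, `computeIfAbsent` is safe to use here:

```java
import java.util.HashMap;
import java.util.Map;

public class PCacheDemo {
    enum Event { A, B, C, D }

    // Hypothetical stand-in for the question's DataPoint.
    record DataPoint(int position) {}

    // Value-type key: equals/hashCode cover all four fields and accept a null leftEvent.
    record PKey(DataPoint right, Event rightEvent, DataPoint left, Event leftEvent) {}

    private final Map<PKey, Double> previousResults = new HashMap<>();
    int expensiveCalls = 0;

    double p(DataPoint right, Event rightEvent, DataPoint left, Event leftEvent) {
        // Safe here: the mapping function does not modify previousResults.
        return previousResults.computeIfAbsent(
                new PKey(right, rightEvent, left, leftEvent),
                k -> expensiveP(k.right(), k.rightEvent(), k.left(), k.leftEvent()));
    }

    // Hypothetical placeholder for the real expensive calculation.
    private double expensiveP(DataPoint right, Event rightEvent, DataPoint left, Event leftEvent) {
        expensiveCalls++;
        int le = leftEvent == null ? 0 : leftEvent.ordinal() + 1;
        return 1.0 / (2 + right.position() + left.position() + rightEvent.ordinal() + le);
    }

    public static void main(String[] args) {
        PCacheDemo demo = new PCacheDemo();
        DataPoint a = new DataPoint(1), b = new DataPoint(2);
        demo.p(b, Event.A, a, Event.C);
        demo.p(b, Event.A, a, Event.C); // cache hit
        demo.p(b, Event.A, a, null);    // different key: the left-hand node case
        System.out.println(demo.expensiveCalls); // prints 2
    }
}
```

The repeated call is served from the cache, while the call with a `null` left event is treated as a distinct key, so the expensive calculation runs twice.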

Thanks for all your suggestions. I implemented my solution by creating nested classes `P` and `F` to represent the arguments of `p` and `f`, and used a `HashMap` for each to store the results already calculated. The `HashMap` is queried before any computation takes place: if the result is present, it is returned directly; if not, it is computed and added to the `HashMap`.

The final product looks a bit like this:

```java
import java.util.HashMap;
import java.util.Map;
import java.util.NavigableSet;
import java.util.Objects;

public class ProbabilityCalculator {

    private final NavigableSet<DataPoint> points;

    private ProbabilityCalculator(NavigableSet<DataPoint> points) {
        this.points = points;
    }

    // Cache key for p(): all four arguments (leftEvent is null for the left-hand node).
    private static class P {
        public final DataPoint left;
        public final Event leftEvent;
        public final DataPoint right;
        public final Event rightEvent;

        public P(DataPoint left, Event leftEvent, DataPoint right, Event rightEvent) {
            this.left = left;
            this.leftEvent = leftEvent;
            this.right = right;
            this.rightEvent = rightEvent;
        }

        @Override
        public boolean equals(Object o) {
            if (!(o instanceof P)) return false;
            P p = (P) o;
            return left.equals(p.left) && right.equals(p.right)
                    && Objects.equals(leftEvent, p.leftEvent)
                    && Objects.equals(rightEvent, p.rightEvent);
        }

        @Override
        public int hashCode() {
            return Objects.hash(left, leftEvent, right, rightEvent);
        }
    }

    private final Map<P, Double> usedPs = new HashMap<P, Double>();

    // Cache key for f(): the point and its event. f's results also depend on which
    // points are in the set, so this cache is cleared at the start of every S().
    private static class F {
        public final DataPoint dataPoint;
        public final Event dataPointEvent;

        public F(DataPoint dataPoint, Event dataPointEvent) {
            this.dataPoint = dataPoint;
            this.dataPointEvent = dataPointEvent;
        }

        @Override
        public boolean equals(Object o) {
            if (!(o instanceof F)) return false;
            F f = (F) o;
            return dataPoint.equals(f.dataPoint) && Objects.equals(dataPointEvent, f.dataPointEvent);
        }

        @Override
        public int hashCode() {
            return Objects.hash(dataPoint, dataPointEvent);
        }
    }

    private final Map<F, Double> usedFs = new HashMap<F, Double>();

    private Double p(DataPoint right, Event rightEvent, DataPoint left, Event leftEvent) {
        P newP = new P(left, leftEvent, right, rightEvent);

        Double cached = usedPs.get(newP);
        if (cached != null) return cached;

        Double result = 0.0;
        // do some stuff (the expensive calculation that sets result)

        usedPs.put(newP, result);
        return result;
    }

    private Double f(DataPoint right, Event rightEvent) {
        F newF = new F(right, rightEvent);

        Double cached = usedFs.get(newF);
        if (cached != null) return cached;

        DataPoint left = points.lower(right);
        Double result = 0.0;

        if (left.isLefthandNode()) {
            result = 0.25 * p(right, rightEvent, left, null);
        } else if (left.isQ()) {
            result = p(right, rightEvent, left, left.getQEvent()) * f(left, left.getQEvent());
        } else { // if M_k
            for (Event leftEvent : left.getEvents()) {
                result += p(right, rightEvent, left, leftEvent) * f(left, leftEvent);
            }
        }

        usedFs.put(newF, result);
        return result;
    }

    public Double S() {
        usedFs.clear(); // f values from an earlier call are stale if points has changed
        return f(points.last(), points.last().getRightNodeEvent());
    }

    public static Double probabilityOfQ(DataPoint q, NavigableSet<DataPoint> points) {
        ProbabilityCalculator pc = new ProbabilityCalculator(points);

        Double S1 = pc.S();