001    /*
002     * Copyright 2011 Christian Kumpe http://kumpe.de/christian/java
003     *
004     * Licensed under the Apache License, Version 2.0 (the "License");
005     * you may not use this file except in compliance with the License.
006     * You may obtain a copy of the License at
007     *
008     *     http://www.apache.org/licenses/LICENSE-2.0
009     *
010     * Unless required by applicable law or agreed to in writing, software
011     * distributed under the License is distributed on an "AS IS" BASIS,
012     * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013     * See the License for the specific language governing permissions and
014     * limitations under the License.
015     */
016    package de.kumpe.hadooptimizer.examples.neurons;
017    
import java.io.Serializable;
import java.util.Collection;
import java.util.Locale;

import de.kumpe.hadooptimizer.EvaluationResult;
import de.kumpe.hadooptimizer.Evaluator;
import de.kumpe.hadooptimizer.examples.neurons.Neuron.Input;
import de.kumpe.hadooptimizer.impl.ReportingHalterWrapper.Reporter;
025    
026    public class NeuronalNetEvaluator implements Evaluator<double[]>,
027                    Reporter<double[]> {
028            private static final long serialVersionUID = 1L;
029    
030            public static class Sample implements Serializable {
031                    private static final long serialVersionUID = 1L;
032    
033                    private final double[] inputValues;
034                    private final double outputValue;
035    
036                    public Sample(final double[] inputValues, final double outputValue) {
037                            this.inputValues = inputValues;
038                            this.outputValue = outputValue;
039                    }
040            }
041    
042            final InputValue[] inputValues;
043            final Neuron.Input[] inputs;
044            final Value output;
045            final Sample[] samples;
046    
047            public NeuronalNetEvaluator(final InputValue[] inputValues,
048                            final Input[] inputs, final Value output, final Sample[] samples) {
049                    this.inputValues = inputValues;
050                    this.inputs = inputs;
051                    this.output = output;
052                    this.samples = samples;
053            }
054    
055            public void print(final double[] weights) {
056                    for (int i = 0; i < inputs.length; i++) {
057                            inputs[i].setWeight(weights[i]);
058                    }
059    
060                    for (final Sample sample : samples) {
061                            for (int i = 0; i < inputValues.length; i++) {
062                                    final double inputValue = sample.inputValues[i];
063                                    if (0 < i) {
064                                            System.out.print(", ");
065                                    }
066                                    System.out.print(inputValue);
067                                    inputValues[i].set(inputValue);
068                            }
069                            System.out.printf(" => %f%n", output.value());
070                    }
071            }
072    
073            @Override
074            public double evaluate(final double[] weights) {
075                    double error = 0;
076                    for (final Sample sample : samples) {
077                            final double diff = evaluate(weights, sample.inputValues)
078                                            - sample.outputValue;
079                            error += diff * diff;
080                    }
081                    return error;
082            }
083    
084            double evaluate(final double[] weights, final double[] inputValues) {
085                    for (int i = 0; i < this.inputValues.length; i++) {
086                            this.inputValues[i].set(inputValues[i]);
087                    }
088    
089                    for (int i = 0; i < inputs.length; i++) {
090                            inputs[i].setWeight(weights[i]);
091                    }
092    
093                    return output.value();
094            }
095    
096            @Override
097            public void report(
098                            final Collection<EvaluationResult<double[]>> evaluationResults) {
099                    final double[] weights = evaluationResults.iterator().next()
100                                    .getIndividual();
101                    print(weights);
102            }
103    }