001 /*
002 * Copyright 2011 Christian Kumpe http://kumpe.de/christian/java
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 * http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016 package de.kumpe.hadooptimizer.examples.neurons;
017
018 import java.io.Serializable;
019 import java.util.Collection;
020
021 import de.kumpe.hadooptimizer.EvaluationResult;
022 import de.kumpe.hadooptimizer.Evaluator;
023 import de.kumpe.hadooptimizer.examples.neurons.Neuron.Input;
024 import de.kumpe.hadooptimizer.impl.ReportingHalterWrapper.Reporter;
025
026 public class NeuronalNetEvaluator implements Evaluator<double[]>,
027 Reporter<double[]> {
028 private static final long serialVersionUID = 1L;
029
030 public static class Sample implements Serializable {
031 private static final long serialVersionUID = 1L;
032
033 private final double[] inputValues;
034 private final double outputValue;
035
036 public Sample(final double[] inputValues, final double outputValue) {
037 this.inputValues = inputValues;
038 this.outputValue = outputValue;
039 }
040 }
041
042 final InputValue[] inputValues;
043 final Neuron.Input[] inputs;
044 final Value output;
045 final Sample[] samples;
046
047 public NeuronalNetEvaluator(final InputValue[] inputValues,
048 final Input[] inputs, final Value output, final Sample[] samples) {
049 this.inputValues = inputValues;
050 this.inputs = inputs;
051 this.output = output;
052 this.samples = samples;
053 }
054
055 public void print(final double[] weights) {
056 for (int i = 0; i < inputs.length; i++) {
057 inputs[i].setWeight(weights[i]);
058 }
059
060 for (final Sample sample : samples) {
061 for (int i = 0; i < inputValues.length; i++) {
062 final double inputValue = sample.inputValues[i];
063 if (0 < i) {
064 System.out.print(", ");
065 }
066 System.out.print(inputValue);
067 inputValues[i].set(inputValue);
068 }
069 System.out.printf(" => %f%n", output.value());
070 }
071 }
072
073 @Override
074 public double evaluate(final double[] weights) {
075 double error = 0;
076 for (final Sample sample : samples) {
077 final double diff = evaluate(weights, sample.inputValues)
078 - sample.outputValue;
079 error += diff * diff;
080 }
081 return error;
082 }
083
084 double evaluate(final double[] weights, final double[] inputValues) {
085 for (int i = 0; i < this.inputValues.length; i++) {
086 this.inputValues[i].set(inputValues[i]);
087 }
088
089 for (int i = 0; i < inputs.length; i++) {
090 inputs[i].setWeight(weights[i]);
091 }
092
093 return output.value();
094 }
095
096 @Override
097 public void report(
098 final Collection<EvaluationResult<double[]>> evaluationResults) {
099 final double[] weights = evaluationResults.iterator().next()
100 .getIndividual();
101 print(weights);
102 }
103 }