001 /* 002 * Copyright 2011 Christian Kumpe http://kumpe.de/christian/java 003 * 004 * Licensed under the Apache License, Version 2.0 (the "License"); 005 * you may not use this file except in compliance with the License. 006 * You may obtain a copy of the License at 007 * 008 * http://www.apache.org/licenses/LICENSE-2.0 009 * 010 * Unless required by applicable law or agreed to in writing, software 011 * distributed under the License is distributed on an "AS IS" BASIS, 012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 013 * See the License for the specific language governing permissions and 014 * limitations under the License. 015 */ 016 package de.kumpe.hadooptimizer.examples.neurons; 017 018 import java.io.Serializable; 019 import java.util.Collection; 020 021 import de.kumpe.hadooptimizer.EvaluationResult; 022 import de.kumpe.hadooptimizer.Evaluator; 023 import de.kumpe.hadooptimizer.examples.neurons.Neuron.Input; 024 import de.kumpe.hadooptimizer.impl.ReportingHalterWrapper.Reporter; 025 026 public class NeuronalNetEvaluator implements Evaluator<double[]>, 027 Reporter<double[]> { 028 private static final long serialVersionUID = 1L; 029 030 public static class Sample implements Serializable { 031 private static final long serialVersionUID = 1L; 032 033 private final double[] inputValues; 034 private final double outputValue; 035 036 public Sample(final double[] inputValues, final double outputValue) { 037 this.inputValues = inputValues; 038 this.outputValue = outputValue; 039 } 040 } 041 042 final InputValue[] inputValues; 043 final Neuron.Input[] inputs; 044 final Value output; 045 final Sample[] samples; 046 047 public NeuronalNetEvaluator(final InputValue[] inputValues, 048 final Input[] inputs, final Value output, final Sample[] samples) { 049 this.inputValues = inputValues; 050 this.inputs = inputs; 051 this.output = output; 052 this.samples = samples; 053 } 054 055 public void print(final double[] weights) { 056 for (int i = 0; i < inputs.length; i++) { 057 inputs[i].setWeight(weights[i]); 058 } 059 060 for (final Sample sample : samples) { 061 for (int i = 0; i < inputValues.length; i++) { 062 final double inputValue = sample.inputValues[i]; 063 if (0 < i) { 064 System.out.print(", "); 065 } 066 System.out.print(inputValue); 067 inputValues[i].set(inputValue); 068 } 069 System.out.printf(" => %f%n", output.value()); 070 } 071 } 072 073 @Override 074 public double evaluate(final double[] weights) { 075 double error = 0; 076 for (final Sample sample : samples) { 077 final double diff = evaluate(weights, sample.inputValues) 078 - sample.outputValue; 079 error += diff * diff; 080 } 081 return error; 082 } 083 084 double evaluate(final double[] weights, final double[] inputValues) { 085 for (int i = 0; i < this.inputValues.length; i++) { 086 this.inputValues[i].set(inputValues[i]); 087 } 088 089 for (int i = 0; i < inputs.length; i++) { 090 inputs[i].setWeight(weights[i]); 091 } 092 093 return output.value(); 094 } 095 096 @Override 097 public void report( 098 final Collection<EvaluationResult<double[]>> evaluationResults) { 099 final double[] weights = evaluationResults.iterator().next() 100 .getIndividual(); 101 print(weights); 102 } 103 }