java.lang.Object
org.neuroph.core.learning.LearningRule
org.neuroph.core.learning.IterativeLearning
org.neuroph.core.learning.SupervisedLearning
public abstract class SupervisedLearning
extends IterativeLearning
Base class for all supervised learning algorithms. It extends IterativeLearning and provides the general principles of supervised learning.
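
The division of labor described here (the base class runs the epoch, each algorithm supplies only the weight update) is a template-method design. The sketch below is a minimal, self-contained illustration under stated assumptions, not Neuroph's source: the "network" is just a function, and TrainingRow, SupervisedEpochSketch and RecordingLearning are hypothetical names (TrainingRow stands in for DataSetRow).

```java
import java.util.ArrayList;
import java.util.List;
import java.util.function.UnaryOperator;

/** Hypothetical stand-in for a DataSetRow: one input vector and its desired output. */
final class TrainingRow {
    final double[] input;
    final double[] desiredOutput;

    TrainingRow(double[] input, double[] desiredOutput) {
        this.input = input;
        this.desiredOutput = desiredOutput;
    }
}

/**
 * Template-method sketch of a supervised learning epoch (not Neuroph's source):
 * the base class fixes the per-pattern steps, subclasses supply the weight update.
 */
abstract class SupervisedEpochSketch {
    // Hypothetical "network": maps an input vector to an output vector.
    protected final UnaryOperator<double[]> network;
    protected double totalSquaredErrorSum;

    SupervisedEpochSketch(UnaryOperator<double[]> network) {
        this.network = network;
    }

    /** One learning epoch: reset the error sum, then learn every row once. */
    void doLearningEpoch(List<TrainingRow> trainingSet) {
        totalSquaredErrorSum = 0.0;
        for (TrainingRow row : trainingSet) {
            learnPattern(row);
        }
    }

    /** Forward pass, error (desired minus actual), squared-error bookkeeping, weight update. */
    protected void learnPattern(TrainingRow row) {
        double[] output = network.apply(row.input);
        double[] outputError = new double[output.length];
        for (int i = 0; i < output.length; i++) {
            outputError[i] = row.desiredOutput[i] - output[i];
            totalSquaredErrorSum += outputError[i] * outputError[i];
        }
        updateNetworkWeights(outputError);
    }

    /** The one step each concrete learning rule must implement. */
    protected abstract void updateNetworkWeights(double[] outputError);
}

/** Toy subclass that records the error vectors instead of changing any weights. */
final class RecordingLearning extends SupervisedEpochSketch {
    final List<double[]> seenErrors = new ArrayList<>();

    RecordingLearning(UnaryOperator<double[]> network) {
        super(network);
    }

    @Override
    protected void updateNetworkWeights(double[] outputError) {
        seenErrors.add(outputError.clone());
    }
}
```

With a "network" that halves its input, one epoch over the rows (1.0 → 1.0) and (2.0 → 1.0) hands the hook the errors 0.5 and 0.0, and leaves a squared-error sum of 0.25.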
Field Summary
Modifier and Type | Field | Description
---|---|---
protected double | maxError | Max allowed network error (condition to stop learning)
protected double[] | outputError | Stores the network output error vector
protected double | previousEpochError | Total network error in the previous epoch
protected double | totalNetworkError | Total network error
protected double | totalSquaredErrorSum | Total squared sum of all pattern errors
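
The three error fields work together across an epoch. The sketch below is hypothetical: it assumes previousEpochError is captured and the accumulator reset before each epoch, and that totalNetworkError is the mean of the per-pattern squared-error sums. Neuroph may scale or normalize the error differently, and ErrorBookkeepingSketch is a made-up name.

```java
/**
 * Hypothetical sketch of the bookkeeping behind totalSquaredErrorSum,
 * totalNetworkError and previousEpochError (not Neuroph's source).
 * Assumption: epoch error = mean of per-pattern sums of squared errors.
 */
final class ErrorBookkeepingSketch {
    private double totalSquaredErrorSum;
    private double totalNetworkError;
    private double previousEpochError;
    private int patternCount;

    /** Before an epoch: remember the last epoch's error and reset the accumulator. */
    void beforeEpoch() {
        previousEpochError = totalNetworkError;
        totalSquaredErrorSum = 0.0;
        patternCount = 0;
    }

    /** Adds the sum of squared errors of one pattern to the running total. */
    void addToSquaredErrorSum(double[] outputError) {
        double patternSum = 0.0;
        for (double e : outputError) {
            patternSum += e * e;
        }
        totalSquaredErrorSum += patternSum;
        patternCount++;
    }

    /** After an epoch: turn the accumulated sum into the epoch error (assumed mean). */
    void afterEpoch() {
        totalNetworkError = patternCount == 0 ? 0.0 : totalSquaredErrorSum / patternCount;
    }

    double getTotalNetworkError() { return totalNetworkError; }

    double getPreviousEpochError() { return previousEpochError; }
}
```

For example, the patterns with errors {0.5, -0.5} and {1.0} contribute 0.5 and 1.0, so under the mean assumption the epoch error is 0.75, which becomes previousEpochError when the next epoch starts.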
Fields inherited from class org.neuroph.core.learning.IterativeLearning |
---|
currentIteration, iterationsLimited, learningRate, maxIterations |
Fields inherited from class org.neuroph.core.learning.LearningRule |
---|
listeners, neuralNetwork |
Constructor Summary
Constructor | Description
---|---
SupervisedLearning() | Creates a new supervised learning rule
Method Summary
Modifier and Type | Method | Description
---|---|---
protected void | addToSquaredErrorSum(double[] outputError) | Calculates and updates the sum of squared errors for a single pattern, and updates the total sum of squared pattern errors
protected void | afterEpoch() |
protected void | beforeEpoch() |
protected double[] | calculateOutputError(double[] desiredOutput, double[] output) | Calculates the network error for the current input pattern: the difference between desired and actual output
protected void | doBatchWeightsUpdate() | Updates network weights in batch mode, using the accumulated weight changes stored in Weight.deltaWeight. It is executed after each learning epoch, and only if learning is done in batch mode.
void | doLearningEpoch(DataSet trainingSet) | Implements the basic logic of one learning epoch for supervised learning algorithms.
protected boolean | errorChangeStalled() | Returns true if the absolute error change has been sufficiently small (<= minErrorChange) for the number of iterations set by setMinErrorChangeIterationsLimit
double | getMaxError() | Returns the learning error tolerance: the total network error value at which learning stops.
double | getMinErrorChange() | Returns the min error change stopping criterion
int | getMinErrorChangeIterationsCount() | Returns the current iteration count for the min error change stopping criterion
int | getMinErrorChangeIterationsLimit() | Returns the number of iterations for the min error change stopping criterion
double | getPreviousEpochError() | Returns the total network error in the previous learning epoch
double | getTotalNetworkError() | Returns the total network error in the current learning epoch
protected boolean | hasReachedStopCondition() | Returns true if the stop condition has been reached, false otherwise.
boolean | isInBatchMode() | Returns true if learning is performed in batch mode, false otherwise
void | learn(DataSet trainingSet, double maxError) | Trains the network on the specified training set, using maxError as the error stopping criterion
void | learn(DataSet trainingSet, double maxError, int maxIterations) | Trains the network on the specified training set, using maxError as the error stopping criterion and maxIterations as the iteration limit
protected void | learnPattern(DataSetRow trainingElement) | Trains the network with the input and desired output pattern from the specified training element
protected void | onStart() | This method is executed when learning starts, before the first epoch.
void | setBatchMode(boolean batchMode) | Sets batch mode on/off (true/false)
void | setMaxError(double maxError) | Sets the allowed network error, which indicates when to stop training
void | setMinErrorChange(double minErrorChange) | Sets the min error change stopping criterion
void | setMinErrorChangeIterationsLimit(int minErrorChangeIterationsLimit) | Sets the number of iterations for the min error change stopping criterion
protected abstract void | updateNetworkWeights(double[] outputError) | This method should implement the weight update procedure for the whole network for the given output error vector.
Methods inherited from class org.neuroph.core.learning.IterativeLearning |
---|
doOneLearningIteration, getCurrentIteration, getLearningRate, isPausedLearning, learn, learn, pause, resume, setLearningRate, setMaxIterations |
Methods inherited from class org.neuroph.core.learning.LearningRule |
---|
addListener, fireLearningEvent, getNeuralNetwork, getTrainingSet, isStopped, removeListener, setNeuralNetwork, setTrainingSet, stopLearning |
Methods inherited from class java.lang.Object |
---|
clone, equals, finalize, getClass, hashCode, notify, notifyAll, toString, wait, wait, wait |
Field Detail |
---|
protected transient double totalNetworkError
protected transient double totalSquaredErrorSum
protected transient double previousEpochError
protected double maxError
protected double[] outputError
Constructor Detail |
---|
public SupervisedLearning()
Method Detail |
---|
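public void learn(DataSet trainingSet, double maxError)
Trains the network on the specified training set, using maxError as the error stopping criterion.
Parameters:
trainingSet - training set to learn
maxError - maximum allowed total network error (condition to stop learning)

public void learn(DataSet trainingSet, double maxError, int maxIterations)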
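Trains the network on the specified training set, using maxError as the error stopping criterion and maxIterations as the iteration limit.
Parameters:
trainingSet - training set to learn
maxError - maximum allowed total network error (condition to stop learning)
maxIterations - maximum number of learning iterations

protected void onStart()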
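Description copied from class: IterativeLearning
This method is executed when learning starts, before the first epoch.
Overrides:
onStart in class IterativeLearning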

protected void beforeEpoch()
Overrides:
beforeEpoch in class IterativeLearning

protected void afterEpoch()
Overrides:
afterEpoch in class IterativeLearning
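
The hooks onStart, beforeEpoch and afterEpoch, together with doLearningEpoch, are driven by the iterative learning loop inherited from IterativeLearning. The self-contained sketch below shows one plausible ordering: onStart once, then beforeEpoch, doLearningEpoch and afterEpoch per epoch until the error tolerance or iteration limit is reached. The exact ordering, the `<=` comparison against maxError, and the names LifecycleSketch and HalvingErrorLearner are assumptions for illustration, not Neuroph's code.

```java
import java.util.ArrayList;
import java.util.List;

/**
 * Hypothetical driver loop for the lifecycle hooks (not Neuroph's source).
 * Assumed order: onStart once, then beforeEpoch / doLearningEpoch / afterEpoch
 * per epoch, stopping on the error tolerance or the iteration limit.
 */
abstract class LifecycleSketch {
    protected final List<String> calls = new ArrayList<>();
    protected double totalNetworkError = Double.MAX_VALUE;
    protected int currentIteration;

    protected void onStart() {
        calls.add("onStart");
        currentIteration = 0;
    }

    protected void beforeEpoch() { calls.add("beforeEpoch"); }

    protected void afterEpoch() { calls.add("afterEpoch"); }

    protected abstract void doLearningEpoch();

    /** Runs epochs until the error is at most maxError or maxIterations epochs have run. */
    void learn(double maxError, int maxIterations) {
        onStart();
        while (currentIteration < maxIterations) {
            beforeEpoch();
            doLearningEpoch();
            afterEpoch();
            currentIteration++;
            if (totalNetworkError <= maxError) { // assumed comparison
                break;
            }
        }
    }
}

/** Toy learner whose error halves every epoch, so the stopping point is predictable. */
final class HalvingErrorLearner extends LifecycleSketch {
    private double error = 1.0;

    @Override
    protected void doLearningEpoch() {
        calls.add("doLearningEpoch");
        error /= 2.0;
        totalNetworkError = error;
    }
}
```

With maxError = 0.2 the toy error goes 0.5, 0.25, 0.125, so learning stops after three epochs; with an unreachable tolerance the iteration limit ends the loop instead.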
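
public void doLearningEpoch(DataSet trainingSet)
This method implements the basic logic of one learning epoch for supervised learning algorithms.
Specified by:
doLearningEpoch in class IterativeLearning
Parameters:
trainingSet - training set for training the network

protected void learnPattern(DataSetRow trainingElement)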
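Trains the network with the input and desired output pattern from the specified training element.
Parameters:
trainingElement - supervised training element which contains input and desired output

protected void doBatchWeightsUpdate()
Updates network weights in batch mode, using the accumulated weight changes stored in Weight.deltaWeight. It is executed after each learning epoch, and only if learning is done in batch mode.
See Also:
doLearningEpoch(DataSet)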
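
Batch mode changes when weight updates take effect: per-pattern changes are only accumulated (the role Weight.deltaWeight plays here) and applied once at the end of the epoch. The sketch below contrasts the two modes for a single linear neuron trained with the delta rule. It is a hypothetical illustration: BatchModeSketch, its fields, and the plain (unaveraged) sum of accumulated changes are assumptions, not Neuroph's implementation.

```java
/**
 * Hypothetical contrast of online vs. batch updates for one linear neuron
 * (not Neuroph's source). In batch mode the per-pattern changes are summed
 * in deltaWeights and applied once per epoch by doBatchWeightsUpdate().
 */
final class BatchModeSketch {
    final double[] weights;
    private final double[] deltaWeights;
    private final double learningRate;
    private final boolean batchMode;

    BatchModeSketch(int inputs, double learningRate, boolean batchMode) {
        this.weights = new double[inputs];
        this.deltaWeights = new double[inputs];
        this.learningRate = learningRate;
        this.batchMode = batchMode;
    }

    /** One pattern: linear output, error = desired - output, then update or accumulate. */
    void learnPattern(double[] input, double desired) {
        double output = 0.0;
        for (int i = 0; i < input.length; i++) {
            output += weights[i] * input[i];
        }
        double error = desired - output;
        for (int i = 0; i < input.length; i++) {
            double change = learningRate * error * input[i];
            if (batchMode) {
                deltaWeights[i] += change; // accumulate only
            } else {
                weights[i] += change;      // online mode: apply immediately
            }
        }
    }

    /** Applies and clears the accumulated changes (used only in batch mode). */
    void doBatchWeightsUpdate() {
        for (int i = 0; i < weights.length; i++) {
            weights[i] += deltaWeights[i];
            deltaWeights[i] = 0.0;
        }
    }

    void doLearningEpoch(double[][] inputs, double[] desired) {
        for (int p = 0; p < inputs.length; p++) {
            learnPattern(inputs[p], desired[p]);
        }
        if (batchMode) {
            doBatchWeightsUpdate();
        }
    }
}
```

On two identical patterns (input 1.0, target 1.0) with learning rate 0.5, online mode ends the epoch at weight 0.75, because the second pattern already sees the first update; batch mode computes both changes from the same starting weight and ends at 1.0.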
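
protected boolean hasReachedStopCondition()
Returns true if the stop condition has been reached, false otherwise.

protected boolean errorChangeStalled()
Returns true if the absolute error change has been sufficiently small (<= minErrorChange) for the number of iterations set by setMinErrorChangeIterationsLimit.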
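
The min error change criterion can be read as a counter over epochs. The sketch below is a hypothetical reading of it: the counter increments whenever |previousEpochError - totalNetworkError| <= minErrorChange, resets on a larger change (treating the iterations as consecutive is an assumption), and reports a stall once it reaches the limit. ErrorChangeStallSketch is a made-up name; the field names follow the documented getters and setters.

```java
/**
 * Hypothetical sketch of the "min error change" stopping criterion
 * (not Neuroph's source). Assumption: the small-change epochs must be consecutive.
 */
final class ErrorChangeStallSketch {
    private final double minErrorChange;
    private final int minErrorChangeIterationsLimit;
    private int minErrorChangeIterationsCount;

    ErrorChangeStallSketch(double minErrorChange, int minErrorChangeIterationsLimit) {
        this.minErrorChange = minErrorChange;
        this.minErrorChangeIterationsLimit = minErrorChangeIterationsLimit;
    }

    /** Called once per epoch; true once the error has barely moved for enough epochs. */
    boolean errorChangeStalled(double previousEpochError, double totalNetworkError) {
        double absErrorChange = Math.abs(previousEpochError - totalNetworkError);
        if (absErrorChange <= minErrorChange) {
            minErrorChangeIterationsCount++;
        } else {
            minErrorChangeIterationsCount = 0; // a large change resets the streak
        }
        return minErrorChangeIterationsCount >= minErrorChangeIterationsLimit;
    }

    int getMinErrorChangeIterationsCount() {
        return minErrorChangeIterationsCount;
    }
}
```

With minErrorChange = 0.01 and a limit of 2, an error that stays at 0.25 for two epochs in a row triggers the stall, while a subsequent drop to 0.125 would reset the count to zero.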
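
protected double[] calculateOutputError(double[] desiredOutput, double[] output)
Calculates the network error for the current input pattern: the difference between desired and actual output.
Parameters:
desiredOutput - desired network output
output - actual network output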
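
Read literally, "the difference between desired and actual output" is an element-wise subtraction, desired minus actual. The sketch below assumes exactly that sign convention; OutputErrorSketch is a hypothetical name, and the length check is an addition of this sketch rather than documented behavior.

```java
/**
 * Hypothetical sketch of the per-pattern output error (not Neuroph's source):
 * outputError[i] = desiredOutput[i] - output[i].
 */
final class OutputErrorSketch {
    private OutputErrorSketch() {}

    static double[] calculateOutputError(double[] desiredOutput, double[] output) {
        if (desiredOutput.length != output.length) {
            throw new IllegalArgumentException("desired and actual output lengths differ");
        }
        double[] outputError = new double[output.length];
        for (int i = 0; i < output.length; i++) {
            // Positive error: this output neuron's value was too low.
            outputError[i] = desiredOutput[i] - output[i];
        }
        return outputError;
    }
}
```

For desired output {1.0, 0.0} and actual output {0.75, 0.25}, the error vector is {0.25, -0.25}.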
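
public boolean isInBatchMode()
Returns true if learning is performed in batch mode, false otherwise.

public void setBatchMode(boolean batchMode)
Sets batch mode on/off (true/false).
Parameters:
batchMode - batch mode setting

public void setMaxError(double maxError)
Sets the allowed network error, which indicates when to stop training.
Parameters:
maxError - allowed total network error

public double getMaxError()
Returns the learning error tolerance: the total network error value at which learning stops.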
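
public double getTotalNetworkError()
Returns the total network error in the current learning epoch.

public double getPreviousEpochError()
Returns the total network error in the previous learning epoch.

public double getMinErrorChange()
Returns the min error change stopping criterion.

public void setMinErrorChange(double minErrorChange)
Sets the min error change stopping criterion.
Parameters:
minErrorChange - value for the min error change stopping criterion

public int getMinErrorChangeIterationsLimit()
Returns the number of iterations for the min error change stopping criterion.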
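
public void setMinErrorChangeIterationsLimit(int minErrorChangeIterationsLimit)
Sets the number of iterations for the min error change stopping criterion.
Parameters:
minErrorChangeIterationsLimit - number of iterations for the min error change stopping criterion

public int getMinErrorChangeIterationsCount()
Returns the current iteration count for the min error change stopping criterion.

protected void addToSquaredErrorSum(double[] outputError)
Calculates and updates the sum of squared errors for a single pattern, and updates the total sum of squared pattern errors.
Parameters:
outputError - output error vector

protected abstract void updateNetworkWeights(double[] outputError)
This method should implement the weight update procedure for the whole network for the given output error vector.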
Parameters:
outputError - output error vector for some network input (a.k.a. pattern error or network error), usually the difference between desired and actual output
See Also:
calculateOutputError(double[], double[]), addToSquaredErrorSum(double[])
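
As an example of what a concrete updateNetworkWeights(outputError) might do, the sketch below applies the delta (LMS) rule to a single-layer linear network: w[j][i] += learningRate * outputError[j] * input[i]. This is one possible update procedure chosen for illustration, not the implementation of any particular Neuroph subclass; DeltaRuleSketch, its weight layout, and the stored lastInput are hypothetical.

```java
/**
 * Hypothetical single-layer linear network with a delta-rule
 * updateNetworkWeights (not Neuroph's source).
 */
final class DeltaRuleSketch {
    final double[][] weights; // weights[j][i]: from input i to output neuron j
    private final double learningRate;
    private double[] lastInput = new double[0];

    DeltaRuleSketch(int inputs, int outputs, double learningRate) {
        this.weights = new double[outputs][inputs];
        this.learningRate = learningRate;
    }

    /** Linear forward pass; remembers the input for the following weight update. */
    double[] calculate(double[] input) {
        lastInput = input.clone();
        double[] output = new double[weights.length];
        for (int j = 0; j < weights.length; j++) {
            for (int i = 0; i < input.length; i++) {
                output[j] += weights[j][i] * input[i];
            }
        }
        return output;
    }

    /** Delta rule using the error vector for the most recent calculate() call. */
    void updateNetworkWeights(double[] outputError) {
        for (int j = 0; j < weights.length; j++) {
            for (int i = 0; i < lastInput.length; i++) {
                weights[j][i] += learningRate * outputError[j] * lastInput[i];
            }
        }
    }
}
```

Only weights attached to non-zero inputs move: after one update on input {1.0, 0.0} with error 1.0 and learning rate 0.5, the first weight becomes 0.5 and the second stays 0.0.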