public class SGDNetworkTrainer extends Object implements Serializable
Constructor and Description |
---|
SGDNetworkTrainer(): Creates a new SGD network trainer that uses dropout. |
SGDNetworkTrainer(SGDNetworkTrainer toCopy): Copy constructor. |
Modifier and Type | Method and Description |
---|---|
protected SGDNetworkTrainer | clone() |
Vec | feedfoward(Vec x): Feeds the given single pattern through the network and computes its activations. |
void | finishUpdating(): Calling this method indicates that the user has no intention of updating the network again and is ready to use it for prediction. |
BiastInitializer | getBiasInit() |
double | getDropoutHidden() |
double | getDropoutInput() |
double | getEta() |
DecayRate | getEtaDecay() |
GradientUpdater | getGradientUpdater() |
int[] | getLayerSizes() |
WeightRegularizer | getRegularizer() |
WeightInitializer | getWeightInit() |
void | setBiasInit(BiastInitializer biasInit): Sets the method to use when initializing neuron bias values. |
void | setDropoutHidden(double p): Sets the probability of dropping a value from the hidden layer. |
void | setDropoutInput(double p): Sets the probability of dropping a value from the input layer. |
void | setEta(double eta): Sets the base global learning rate. |
void | setEtaDecay(DecayRate etaDecay): Sets the decay rate applied to the global learning rate over time. |
void | setGradientUpdater(GradientUpdater updater): Sets the gradient updater that will be used when updating the weight matrices and bias terms. |
void | setLayersActivation(List<ActivationLayer> layersActivation): Sets the list of layer activations for all layers other than the input layer. |
void | setLayerSizes(int... layerSizes): Sets the size of each layer in the network; the length of the array determines the total number of layers. |
void | setRegularizer(WeightRegularizer regularizer): Sets the method of regularizing the connection weights. |
void | setup(): Prepares the network by creating all needed structure, initializing weights, and preparing it for updates. |
void | setWeightInit(WeightInitializer weightInit): Sets the method used to initialize matrix connection weights. |
double | updateMiniBatch(List<Vec> x, List<Vec> y): Performs a mini-batch update of the network using the given input and output pairs. |
double | updateMiniBatch(List<Vec> x, List<Vec> y, ExecutorService ex): Performs a mini-batch update of the network using the given input and output pairs. |
public SGDNetworkTrainer()
public SGDNetworkTrainer(SGDNetworkTrainer toCopy)
toCopy - the object to copy
public void setDropoutInput(double p)
p - the probability in [0, 1) of dropping a value in the input layer
public double getDropoutInput()
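A p in [0, 1) matches the usual dropout scheme: during each training update, every input value is independently zeroed with probability p (setDropoutHidden presumably applies the same idea to hidden-layer activations). The minimal sketch below shows only that masking step on a plain double[]. DropoutSketch and applyDropout are hypothetical names, and the sketch assumes surviving values are left unscaled, so any compensation for dropped units has to happen elsewhere, for example once training ends.

```java
import java.util.Random;

/** Hypothetical illustration of dropout masking; not part of the SGDNetworkTrainer API. */
final class DropoutSketch {
    private DropoutSketch() {}

    /**
     * Returns a copy of values in which each entry is independently set to 0
     * with probability p, for p in [0, 1). Surviving entries are left unscaled.
     */
    static double[] applyDropout(double[] values, double p, Random rng) {
        if (p < 0 || p >= 1)
            throw new IllegalArgumentException("p must be in [0, 1), got " + p);
        double[] out = values.clone();
        for (int i = 0; i < out.length; i++) {
            if (rng.nextDouble() < p) // nextDouble() is uniform on [0, 1), so this fires with probability p
                out[i] = 0.0;
        }
        return out;
    }
}
```

Passing a seeded Random keeps the mask reproducible, which is convenient when comparing runs.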
public void setDropoutHidden(double p)
p - the probability in [0, 1) of dropping a value in the hidden layer
public double getDropoutHidden()
public void setEtaDecay(DecayRate etaDecay)
etaDecay - the decay rate to use
public DecayRate getEtaDecay()
public void setEta(double eta)
eta - the learning rate to use
public double getEta()
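setEta sets a base learning rate and setEtaDecay sets a schedule that shrinks it over time. The sketch below shows how the two typically combine in one plain SGD step, w ← w − η_t·g, using inverse decay η_t = η_0 / (1 + t/τ) as an example schedule. DecaySchedule, EtaSketch, inverseDecay and tau are hypothetical names; this is not the DecayRate API, and this page does not say which schedules the trainer supports.

```java
/** Hypothetical schedule: maps the base rate and a step counter to the rate used at that step. */
interface DecaySchedule {
    double rate(double baseEta, long step);
}

final class EtaSketch {
    private EtaSketch() {}

    /** Inverse decay: eta_t = eta_0 / (1 + t / tau); the rate halves once t reaches tau. */
    static DecaySchedule inverseDecay(double tau) {
        return (baseEta, step) -> baseEta / (1.0 + step / tau);
    }

    /** One plain SGD step, w <- w - eta_t * grad, applied in place. */
    static void sgdStep(double[] w, double[] grad, double baseEta, DecaySchedule decay, long step) {
        double eta = decay.rate(baseEta, step);
        for (int i = 0; i < w.length; i++)
            w[i] -= eta * grad[i];
    }
}
```

Keeping the schedule separate from the base rate means the same decay object can be reused while tuning eta.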
public void setRegularizer(WeightRegularizer regularizer)
regularizer - the method of regularizing the network
public WeightRegularizer getRegularizer()
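This page does not list which WeightRegularizer implementations exist. As one illustrative technique often paired with dropout, a max-norm constraint rescales each neuron's incoming weight vector back onto a ball of radius c whenever an update pushes it outside. The sketch assumes each row of the matrix holds one neuron's incoming weights; MaxNormSketch and applyMaxNorm are hypothetical names.

```java
/** Hypothetical max-norm constraint on a weight matrix (rows = neurons, columns = inputs). */
final class MaxNormSketch {
    private MaxNormSketch() {}

    /**
     * Rescales, in place, every row whose L2 norm exceeds maxNorm so that its norm
     * becomes maxNorm. Rows already inside the ball are left untouched.
     */
    static void applyMaxNorm(double[][] w, double maxNorm) {
        for (double[] row : w) {
            double sq = 0.0;
            for (double v : row)
                sq += v * v;
            double norm = Math.sqrt(sq);
            if (norm > maxNorm) {
                double scale = maxNorm / norm; // shrink factor that lands the row exactly on the ball
                for (int j = 0; j < row.length; j++)
                    row[j] *= scale;
            }
        }
    }
}
```

A constraint like this is applied after each weight update rather than added to the loss.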
public void setLayerSizes(int... layerSizes)
layerSizes - the array of layer sizes
public int[] getLayerSizes()
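setLayerSizes takes the size of every layer, while setLayersActivation takes activations for every layer except the input layer, which suggests the size array starts with the input dimension. Under that assumption, the sketch below derives the shape of each weight matrix and the total parameter count (weights plus biases) from the size array. LayerShapesSketch, weightShapes and parameterCount are hypothetical names.

```java
import java.util.ArrayList;
import java.util.List;

/** Hypothetical helper: per-connection weight-matrix shapes implied by a layer-size array. */
final class LayerShapesSketch {
    private LayerShapesSketch() {}

    /**
     * Assumes layerSizes[0] is the input dimension and the last entry the output dimension.
     * Connection i maps layer i to layer i + 1, so its matrix is layerSizes[i + 1] x layerSizes[i]
     * (returned as {rows, cols}) and its bias vector has layerSizes[i + 1] entries.
     */
    static List<int[]> weightShapes(int... layerSizes) {
        if (layerSizes.length < 2)
            throw new IllegalArgumentException("need at least an input and an output layer");
        List<int[]> shapes = new ArrayList<>();
        for (int i = 0; i + 1 < layerSizes.length; i++)
            shapes.add(new int[] {layerSizes[i + 1], layerSizes[i]});
        return shapes;
    }

    /** Total trainable parameters: every weight-matrix entry plus every bias term. */
    static int parameterCount(int... layerSizes) {
        int total = 0;
        for (int[] s : weightShapes(layerSizes))
            total += s[0] * s[1] + s[0];
        return total;
    }
}
```

For sizes {4, 3, 2} (4 inputs, one hidden layer of 3, 2 outputs) this gives a 3 x 4 and a 2 x 3 matrix, 23 parameters in total, and the activation list would then need layerSizes.length - 1 = 2 entries.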
public void setLayersActivation(List<ActivationLayer> layersActivation)
layersActivation - the list of hidden and output layer activations
public void setGradientUpdater(GradientUpdater updater)
updater - the updater to use
public GradientUpdater getGradientUpdater()
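A GradientUpdater decides how a computed gradient turns into a change to the weight matrices and bias terms; this page does not list which updaters are available. As an illustration, classical momentum keeps a per-parameter velocity v ← μ·v + g and steps w ← w − η·v. Updater and MomentumUpdater are hypothetical names, not the GradientUpdater interface.

```java
/** Hypothetical updater contract: applies a gradient to parameters in place. */
interface Updater {
    void update(double[] w, double[] grad, double eta);
}

/** Classical momentum: v <- mu * v + grad, then w <- w - eta * v. */
final class MomentumUpdater implements Updater {
    private final double mu;
    private final double[] velocity; // per-parameter state, starts at zero

    MomentumUpdater(int size, double mu) {
        this.mu = mu;
        this.velocity = new double[size];
    }

    @Override
    public void update(double[] w, double[] grad, double eta) {
        for (int i = 0; i < w.length; i++) {
            velocity[i] = mu * velocity[i] + grad[i];
            w[i] -= eta * velocity[i];
        }
    }
}
```

With mu = 0 this reduces to plain SGD. Because the velocity is per-parameter state, each weight matrix and bias vector needs its own updater instance (or buffer).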
public void setWeightInit(WeightInitializer weightInit)
weightInit - the weight initialization method
public WeightInitializer getWeightInit()
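This page does not list which WeightInitializer strategies exist. As one widely used example, Glorot (Xavier) uniform initialization draws each weight from U(−a, a) with a = sqrt(6 / (fanIn + fanOut)). The sketch treats rows as fan-out and columns as fan-in, matching a W·x layout (an assumption); GlorotInitSketch and glorotUniform are hypothetical names. Bias terms are configured separately through setBiasInit.

```java
import java.util.Random;

/** Hypothetical Glorot/Xavier uniform initializer for a rows x cols weight matrix. */
final class GlorotInitSketch {
    private GlorotInitSketch() {}

    /** rows = fan-out, cols = fan-in; entries are uniform on [-a, a) with a = sqrt(6 / (fanIn + fanOut)). */
    static double[][] glorotUniform(int rows, int cols, Random rng) {
        double a = Math.sqrt(6.0 / (rows + cols));
        double[][] w = new double[rows][cols];
        for (int r = 0; r < rows; r++)
            for (int c = 0; c < cols; c++)
                w[r][c] = (2.0 * rng.nextDouble() - 1.0) * a; // maps [0, 1) onto [-a, a)
        return w;
    }
}
```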
public void setBiasInit(BiastInitializer biasInit)
biasInit - the bias initialization method
public BiastInitializer getBiasInit()
public void setup()
public void finishUpdating()
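Once training ends, a dropout network has to be adjusted so that prediction runs without random masks. One standard way is the weight-scaling rule: multiply the outgoing weights of every layer that was dropped with probability p by the keep probability 1 − p, so the expected input to the next layer matches training. Whether finishUpdating performs exactly this step is not stated on this page; WeightScalingSketch and scaleForInference are hypothetical names.

```java
/** Hypothetical sketch of the dropout weight-scaling rule applied after training. */
final class WeightScalingSketch {
    private WeightScalingSketch() {}

    /**
     * Multiplies, in place, a weight matrix whose input layer was dropped with
     * probability p during training by the keep probability (1 - p). Because every
     * input shares the same p, scaling each column equals scaling the whole matrix.
     * Bias terms are not affected by input dropout and are left alone.
     */
    static void scaleForInference(double[][] w, double p) {
        double keep = 1.0 - p;
        for (double[] row : w)
            for (int j = 0; j < row.length; j++)
                row[j] *= keep;
    }
}
```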
public double updateMiniBatch(List<Vec> x, List<Vec> y)
x - the list of input values
y - the list of output values
public double updateMiniBatch(List<Vec> x, List<Vec> y, ExecutorService ex)
x - the list of input values
y - the list of output values
ex - the source of threads for parallel computation; may be null
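updateMiniBatch takes a batch of input/output pairs and, optionally, an ExecutorService that may be null. The sketch below mirrors that shape under stated assumptions: a single linear layer with squared loss stands in for the full network, per-example gradients are computed in parallel when an executor is supplied (serially otherwise), averaged, and applied as one SGD step. The page does not say what the returned double means; here it is assumed to be the mean loss over the batch before the step. MiniBatchSketch and its members are hypothetical.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Future;

/** Hypothetical mini-batch SGD for one linear layer y = W x with loss 0.5 * ||W x - y||^2. */
final class MiniBatchSketch {
    final double[][] w; // outputs x inputs, starts at zero
    private final double eta;

    MiniBatchSketch(int inputs, int outputs, double eta) {
        this.w = new double[outputs][inputs];
        this.eta = eta;
    }

    /** Adds dLoss/dW for one example into grad and returns that example's loss. */
    private double exampleGradient(double[] x, double[] y, double[][] grad) {
        double loss = 0.0;
        for (int o = 0; o < w.length; o++) {
            double pred = 0.0;
            for (int i = 0; i < x.length; i++) pred += w[o][i] * x[i];
            double err = pred - y[o];
            loss += 0.5 * err * err;
            for (int i = 0; i < x.length; i++) grad[o][i] += err * x[i];
        }
        return loss;
    }

    /** Averages gradients over the batch, takes one SGD step, returns the mean loss before the step. */
    double updateMiniBatch(List<double[]> xs, List<double[]> ys, ExecutorService ex) {
        int n = xs.size();
        double[][][] grads = new double[n][w.length][w[0].length]; // one buffer per example: no shared writes
        double[] losses = new double[n];
        if (ex == null) { // no executor supplied: run serially
            for (int k = 0; k < n; k++)
                losses[k] = exampleGradient(xs.get(k), ys.get(k), grads[k]);
        } else {
            List<Future<Double>> futures = new ArrayList<>();
            for (int k = 0; k < n; k++) {
                final int idx = k;
                Callable<Double> task = () -> exampleGradient(xs.get(idx), ys.get(idx), grads[idx]);
                futures.add(ex.submit(task));
            }
            try {
                for (int k = 0; k < n; k++)
                    losses[k] = futures.get(k).get(); // waits; w is not modified until all tasks finish
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
                throw new IllegalStateException("interrupted while waiting for gradients", e);
            } catch (ExecutionException e) {
                throw new IllegalStateException("gradient task failed", e.getCause());
            }
        }
        double meanLoss = 0.0;
        for (int k = 0; k < n; k++) {
            meanLoss += losses[k] / n;
            for (int o = 0; o < w.length; o++)
                for (int i = 0; i < w[o].length; i++)
                    w[o][i] -= eta * grads[k][o][i] / n;
        }
        return meanLoss;
    }
}
```

Giving each example its own gradient buffer avoids locking, at the cost of memory proportional to the batch size; the serial and parallel paths therefore produce identical updates.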
public Vec feedfoward(Vec x)
x - the input vector to feed forward through the network
protected SGDNetworkTrainer clone()
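feedfoward (spelled this way in the API) pushes one input through every layer. The computation it names is the standard forward pass, where each non-input layer computes a' = f(W·a + b); the sketch below shows it with plain arrays. FeedForwardSketch, the {rows = outputs, cols = inputs} weight layout, and the use of DoubleUnaryOperator for activations are assumptions, and this page does not say whether dropout is involved in this call.

```java
import java.util.List;
import java.util.function.DoubleUnaryOperator;

/** Hypothetical forward pass: layer i + 1 = f_i(W_i * layer_i + b_i). */
final class FeedForwardSketch {
    private FeedForwardSketch() {}

    static double[] feedForward(double[] x, List<double[][]> weights, List<double[]> biases,
                                List<DoubleUnaryOperator> activations) {
        double[] a = x;
        for (int layer = 0; layer < weights.size(); layer++) {
            double[][] w = weights.get(layer); // rows = neurons in the next layer
            double[] b = biases.get(layer);
            DoubleUnaryOperator f = activations.get(layer);
            double[] next = new double[w.length];
            for (int o = 0; o < w.length; o++) {
                double z = b[o];
                for (int i = 0; i < a.length; i++)
                    z += w[o][i] * a[i];
                next[o] = f.applyAsDouble(z);
            }
            a = next; // this layer's activations feed the next layer
        }
        return a;
    }
}
```

One activation per weight matrix is consistent with setLayersActivation covering every layer except the input layer.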
Copyright © 2017. All rights reserved.