public class DReDNetSimple extends Object implements Classifier, Parameterized
Constructor | Description |
---|---|
DReDNetSimple() | Creates a new DReDNet that uses two hidden layers with 1024 neurons each. |
DReDNetSimple(int... hiddenLayerSizes) | Creates a new DReDNet that uses the specified number of hidden layers. |
Modifier and Type | Method | Description |
---|---|---|
CategoricalResults | classify(DataPoint data) | Performs classification on the given data point. |
DReDNetSimple | clone() | |
int | getBatchSize() | |
int | getEpochs() | |
int[] | getHiddenSizes() | |
Parameter | getParameter(String paramName) | Returns the parameter with the given name. |
List<Parameter> | getParameters() | Returns the list of parameters that can be altered for this learner. |
void | setBatchSize(int batchSize) | Sets the batch size for updates. |
void | setEpochs(int epochs) | Sets the number of epochs to perform. |
void | setHiddenSizes(int[] hiddenSizes) | Sets the hidden layer sizes for this network. |
boolean | supportsWeightedData() | Indicates whether the model knows how to train using weighted data points. |
void | trainC(ClassificationDataSet dataSet) | Trains the classifier and constructs a model for classification using the given data set. |
void | trainC(ClassificationDataSet dataSet, ExecutorService threadPool) | Trains the classifier and constructs a model for classification using the given data set. |
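classify(DataPoint) returns a CategoricalResults, i.e. a score for every class rather than a single label. The sketch below shows one common way a neural network produces such scores: a softmax over the raw output-layer values, followed by picking the most likely class. It is a hypothetical illustration, not JSAT code; whether DReDNetSimple uses a softmax output layer is an assumption, and the class name SoftmaxOutput is made up.

```java
/**
 * Hypothetical sketch (not JSAT's implementation): converting raw output-layer scores
 * into one probability per class with softmax, then choosing the most likely class.
 * Assumes a softmax output layer, which this page does not confirm.
 */
final class SoftmaxOutput {
    private SoftmaxOutput() {}

    static double[] softmax(double[] scores) {
        // Subtracting the largest score keeps Math.exp from overflowing on big inputs.
        double max = Double.NEGATIVE_INFINITY;
        for (double s : scores) {
            max = Math.max(max, s);
        }
        double[] probs = new double[scores.length];
        double sum = 0.0;
        for (int i = 0; i < scores.length; i++) {
            probs[i] = Math.exp(scores[i] - max);
            sum += probs[i];
        }
        for (int i = 0; i < probs.length; i++) {
            probs[i] /= sum; // normalize so the probabilities add up to 1
        }
        return probs;
    }

    /** Index of the largest probability; ties go to the lowest index. */
    static int mostLikely(double[] probs) {
        int best = 0;
        for (int i = 1; i < probs.length; i++) {
            if (probs[i] > probs[best]) {
                best = i;
            }
        }
        return best;
    }
}
```

For scores {1.0, 3.0, 2.0}, mostLikely(softmax(...)) picks index 1, and the returned probabilities sum to 1.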
public DReDNetSimple()
public DReDNetSimple(int... hiddenLayerSizes)
hiddenLayerSizes
- the length indicates the number of hidden layers, and the value in each index is the number of neurons in that layer
public void setHiddenSizes(int[] hiddenSizes)
hiddenSizes
- the number of neurons in each hidden layer, one entry per layer
public int[] getHiddenSizes()
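The hidden-size array fully determines the network's shape: its length is the number of hidden layers and each entry is a layer width. The sketch below is a hypothetical stand-in (HiddenLayerShape is not a JSAT class) that mirrors the documented default of two 1024-neuron layers and counts the trainable weights and biases, assuming every layer is fully connected to the next and each non-input neuron has one bias.

```java
import java.util.Arrays;

/**
 * Hypothetical stand-in (not part of JSAT): an int[] of hidden sizes describes a
 * fully connected network; entry i is the width of hidden layer i.
 */
final class HiddenLayerShape {
    private int[] hiddenSizes;

    /** Mirrors the documented default: two hidden layers of 1024 neurons each. */
    HiddenLayerShape() {
        this(1024, 1024);
    }

    /** Varargs: the number of arguments is the number of hidden layers. */
    HiddenLayerShape(int... hiddenSizes) {
        setHiddenSizes(hiddenSizes);
    }

    void setHiddenSizes(int[] hiddenSizes) {
        for (int size : hiddenSizes) {
            if (size <= 0) {
                throw new IllegalArgumentException("layer sizes must be positive: " + size);
            }
        }
        // Defensive copy so later changes to the caller's array do not alter the network.
        this.hiddenSizes = Arrays.copyOf(hiddenSizes, hiddenSizes.length);
    }

    int[] getHiddenSizes() {
        return Arrays.copyOf(hiddenSizes, hiddenSizes.length);
    }

    /** Weights plus biases, assuming fully connected layers with one bias per neuron. */
    long parameterCount(int inputDim, int numClasses) {
        long total = 0;
        int previous = inputDim;
        for (int size : hiddenSizes) {
            total += (long) previous * size + size;
            previous = size;
        }
        total += (long) previous * numClasses + numClasses; // output layer
        return total;
    }
}
```

Under these assumptions the default shape with 784 inputs and 10 classes has 784·1024 + 1024·1024 + 1024·10 weights plus 2058 biases, so new HiddenLayerShape().parameterCount(784, 10) returns 1,863,690.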
public void setBatchSize(int batchSize)
batchSize
- the number of items to compute the gradient from
public int getBatchSize()
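The batch size is the number of items each gradient is computed from. As a hypothetical illustration (not JSAT code), the sketch below uses a one-parameter linear model w·x with squared loss 0.5·(w·x − y)² and averages the per-example gradients over one mini-batch before taking a single gradient-descent step. The class name MiniBatchGradient and the toy model are assumptions made for the example.

```java
/**
 * Hypothetical illustration (not JSAT code): one update's gradient is the average over
 * batchSize examples. Toy model: prediction w * x, loss 0.5 * (w * x - y)^2.
 */
final class MiniBatchGradient {
    private MiniBatchGradient() {}

    /** Average gradient over examples [start, start + batchSize), clipped to the data. */
    static double averageGradient(double w, double[] xs, double[] ys, int start, int batchSize) {
        int end = Math.min(start + batchSize, xs.length);
        if (start < 0 || start >= end) {
            throw new IllegalArgumentException("empty mini-batch");
        }
        double sum = 0.0;
        for (int i = start; i < end; i++) {
            sum += (w * xs[i] - ys[i]) * xs[i]; // d/dw of 0.5 * (w*x - y)^2
        }
        return sum / (end - start);
    }

    /** One gradient-descent step using a single mini-batch. */
    static double step(double w, double learningRate, double[] xs, double[] ys, int start, int batchSize) {
        return w - learningRate * averageGradient(w, xs, ys, start, batchSize);
    }
}
```

With xs = {1, 2}, ys = {2, 4} and w = 0, the two per-example gradients are −2 and −8, so a batch of 2 averages to −5 and a step with learning rate 0.1 moves w to 0.5. Larger batches give smoother gradient estimates at the cost of fewer updates per pass.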
public void setEpochs(int epochs)
epochs
- the number of training iterations through the whole data set
public int getEpochs()
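An epoch is one full pass over the data set, so the epoch count and the batch size together fix how many weight updates training performs. The hypothetical schedule below (EpochLoop is not a JSAT class, and reshuffling every epoch is an assumption, not documented behavior) lists the mini-batches a trainer would visit.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Random;

/**
 * Hypothetical training-loop skeleton (not JSAT code): each epoch shuffles the example
 * indices and walks through them in mini-batches of at most batchSize items.
 */
final class EpochLoop {
    private EpochLoop() {}

    static List<List<Integer>> schedule(int numExamples, int batchSize, int epochs, long seed) {
        if (batchSize <= 0) {
            throw new IllegalArgumentException("batchSize must be positive");
        }
        List<Integer> order = new ArrayList<>();
        for (int i = 0; i < numExamples; i++) {
            order.add(i);
        }
        Random rng = new Random(seed);
        List<List<Integer>> batches = new ArrayList<>();
        for (int epoch = 0; epoch < epochs; epoch++) {
            Collections.shuffle(order, rng); // visit the examples in a new order each epoch
            for (int start = 0; start < numExamples; start += batchSize) {
                int end = Math.min(start + batchSize, numExamples);
                // A real trainer would compute the gradient over this mini-batch and update
                // the weights here. Copy the view, since `order` is reshuffled next epoch.
                batches.add(new ArrayList<>(order.subList(start, end)));
            }
        }
        return batches;
    }
}
```

The number of updates is epochs × ⌈n / batchSize⌉; for 10 examples, a batch size of 4 and 3 epochs that is 3 × 3 = 9 mini-batches, the last of each epoch holding only 2 items.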
public CategoricalResults classify(DataPoint data)
Description copied from interface: Classifier
Performs classification on the given data point.
Specified by: classify in interface Classifier
data
- the data point to classify
public void trainC(ClassificationDataSet dataSet, ExecutorService threadPool)
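Description copied from interface: Classifier
Trains the classifier and constructs a model for classification using the given data set.
Specified by: trainC in interface Classifier
dataSet
- the data set to train on
threadPool
- the source of threads to use.
public void trainC(ClassificationDataSet dataSet)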
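Description copied from interface: Classifier
Trains the classifier and constructs a model for classification using the given data set.
Specified by: trainC in interface Classifier
dataSet
- the data set to train on
public boolean supportsWeightedData()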
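Description copied from interface: Classifier
Indicates whether the model knows how to train using weighted data points.
Specified by: supportsWeightedData in interface Classifier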
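Two training details on this page, the threadPool argument of trainC and training on weighted data points, can be illustrated together. The hypothetical sketch below (ParallelWeightedGradient is not a JSAT class, and nothing here says what DReDNetSimple.supportsWeightedData() actually returns) splits a weighted average gradient into chunks, submits each chunk to a caller-supplied ExecutorService, and combines the partial sums. The toy model is again w·x with squared loss; with all weights equal to 1 it reduces to the ordinary average.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Future;

/**
 * Hypothetical sketch (not JSAT code): weighted average gradient of 0.5 * (w*x - y)^2,
 * with per-chunk partial sums computed on a caller-supplied thread pool.
 */
final class ParallelWeightedGradient {
    private ParallelWeightedGradient() {}

    static double averageGradient(double w, double[] xs, double[] ys, double[] weights,
                                  int chunks, ExecutorService pool)
            throws InterruptedException, ExecutionException {
        int n = xs.length;
        if (n == 0 || chunks <= 0) {
            throw new IllegalArgumentException("need at least one example and one chunk");
        }
        int chunkSize = (n + chunks - 1) / chunks; // ceiling division
        List<Future<double[]>> parts = new ArrayList<>();
        for (int start = 0; start < n; start += chunkSize) {
            final int from = start;
            final int to = Math.min(start + chunkSize, n);
            Callable<double[]> task = () -> {
                double gradSum = 0.0;
                double weightSum = 0.0;
                for (int i = from; i < to; i++) {
                    gradSum += weights[i] * (w * xs[i] - ys[i]) * xs[i]; // weighted gradient term
                    weightSum += weights[i];
                }
                return new double[] {gradSum, weightSum};
            };
            parts.add(pool.submit(task));
        }
        double gradTotal = 0.0;
        double weightTotal = 0.0;
        for (Future<double[]> part : parts) {
            double[] r = part.get(); // blocks until this chunk has finished
            gradTotal += r[0];
            weightTotal += r[1];
        }
        if (weightTotal <= 0.0) {
            throw new IllegalArgumentException("weights must sum to a positive value");
        }
        return gradTotal / weightTotal;
    }
}
```

The method never shuts the pool down: the caller owns it, which fits the description of threadPool as a source of threads rather than something the training call manages. With xs = {1, 2, 3, 4}, ys = {2, 4, 6, 8} and w = 0, unit weights give −15, while weights {3, 1, 0, 0} give (3·(−2) + 1·(−8)) / 4 = −3.5.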
public DReDNetSimple clone()
Specified by: clone in interface Classifier
Overrides: clone in class Object
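clone() is declared with the covariant return type DReDNetSimple. Because configuration such as the hidden sizes lives in an int[], a correct clone has to copy that array rather than share it. The sketch below shows the standard Java idiom on a hypothetical NetConfig class; it is not DReDNetSimple's actual implementation.

```java
/**
 * Hypothetical sketch (not JSAT's code): covariant clone() for a learner whose
 * configuration includes a mutable int[]; the array is deep-copied so the clone and
 * the original can be reconfigured independently.
 */
final class NetConfig implements Cloneable {
    private int[] hiddenSizes;
    private final int epochs;

    NetConfig(int epochs, int... hiddenSizes) {
        this.epochs = epochs;
        this.hiddenSizes = hiddenSizes.clone();
    }

    int[] getHiddenSizes() {
        return hiddenSizes.clone();
    }

    void setHiddenSize(int layer, int size) {
        hiddenSizes[layer] = size;
    }

    int getEpochs() {
        return epochs;
    }

    @Override
    public NetConfig clone() { // covariant return type instead of Object
        try {
            NetConfig copy = (NetConfig) super.clone(); // shallow field-by-field copy
            copy.hiddenSizes = hiddenSizes.clone();      // deep-copy the mutable array
            return copy;
        } catch (CloneNotSupportedException e) {
            throw new AssertionError(e); // cannot happen: this class implements Cloneable
        }
    }
}
```

Changing a layer size on the copy leaves the original untouched, which matters when one configured learner is cloned repeatedly, for example once per cross-validation fold.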
public List<Parameter> getParameters()
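Description copied from interface: Parameterized
Returns the list of parameters that can be altered for this learner.
Specified by: getParameters in interface Parameterized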
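getParameters() and getParameter(String) expose the learner's tunable settings as a list of named entries plus a lookup by name. The sketch below is a hypothetical illustration only: Param is a minimal stand-in, not JSAT's Parameter class, and the names "BatchSize" and "Epochs" are invented. It uses exact, case-sensitive matching and returns null for unknown names; how the real method treats an unknown name is not stated on this page.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;

/**
 * Hypothetical sketch (not JSAT's Parameter API): tunable settings exposed as a list of
 * named entries, with a linear scan for lookup by name.
 */
final class NamedParams {
    /** Minimal stand-in for a tunable setting; JSAT's real Parameter class differs. */
    static final class Param {
        final String name;
        int value;

        Param(String name, int value) {
            this.name = name;
            this.value = value;
        }
    }

    private final List<Param> params = new ArrayList<>();

    NamedParams(int batchSize, int epochs) {
        params.add(new Param("BatchSize", batchSize)); // hypothetical parameter names
        params.add(new Param("Epochs", epochs));
    }

    /** Read-only view of every tunable setting. */
    List<Param> getParameters() {
        return Collections.unmodifiableList(params);
    }

    /** Exact, case-sensitive match; null when no parameter has that name. */
    Param getParameter(String paramName) {
        for (Param p : params) {
            if (p.name.equals(paramName)) {
                return p;
            }
        }
        return null;
    }
}
```

A generic tuner can iterate over getParameters() without knowing the concrete learner, while getParameter(name) serves callers that already know which setting they want.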
public Parameter getParameter(String paramName)
Description copied from interface: Parameterized
Returns the parameter with the given name.
Specified by: getParameter in interface Parameterized
paramName
- the name of the parameter to obtain
Copyright © 2017. All rights reserved.