public class LinearSGD extends BaseUpdateableClassifier implements UpdateableRegressor, Parameterized, SimpleWeightVectorModel
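LinearSGD learns either a classification or a regression problem, depending on the loss function ℓ(w,x) used. The solution attempts to minimize

$$\sum_i \ell(w, x_i) + \frac{\lambda_0}{2} \lVert w \rVert_2^2 + \lambda_1 \lVert w \rVert_1$$

and is trained by Stochastic Gradient Descent. Setting λ1 to the desired value divided by the number of unique data points in the whole data set will result in the correct regularization penalty being applied.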
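To make the objective concrete, the following Java sketch evaluates it for a fixed weight vector on a small data set. It is not taken from the library. The class `ObjectiveSketch`, its nested `Loss` interface and the helper `perPointLambda1` are hypothetical. The sketch also assumes that ℓ depends on w and x only through the score w·x and a target y, and it omits any bias term.

```java
/** Hypothetical sketch of the LinearSGD objective; not library code. */
final class ObjectiveSketch {
    private ObjectiveSketch() {}

    /** Assumed loss shape: l(w, x) evaluated on the score w·x and a target y. */
    interface Loss {
        double value(double score, double target);
    }

    static double dot(double[] w, double[] x) {
        double s = 0.0;
        for (int j = 0; j < w.length; j++) {
            s += w[j] * x[j];
        }
        return s;
    }

    /** Evaluates sum_i loss(w·x_i, y_i) + (lambda0 / 2) * ||w||_2^2 + lambda1 * ||w||_1. */
    static double objective(double[] w, double[][] xs, double[] ys,
                            Loss loss, double lambda0, double lambda1) {
        double total = 0.0;
        for (int i = 0; i < xs.length; i++) {
            total += loss.value(dot(w, xs[i]), ys[i]);
        }
        double l2sq = 0.0; // squared L2 norm of w
        double l1 = 0.0;   // L1 norm of w
        for (double wj : w) {
            l2sq += wj * wj;
            l1 += Math.abs(wj);
        }
        return total + lambda0 / 2.0 * l2sq + lambda1 * l1;
    }

    /** Follows the note above literally: the desired lambda1 divided by the number of unique points. */
    static double perPointLambda1(double desiredLambda1, int numUniquePoints) {
        return desiredLambda1 / numUniquePoints;
    }
}
```

Take squared loss ½(s − y)² as an example. For a weight vector that fits every point exactly, the loss sum is zero and the objective reduces to the two penalty terms.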

| Constructor | Description |
|---|---|
| LinearSGD() | Creates a new LinearSGD learner for multi-class classification problems. |
| LinearSGD(LinearSGD toClone) | Copy constructor. |
| LinearSGD(LossFunc loss, double eta, DecayRate decay, double lambda0, double lambda1) | Creates a new LinearSGD learner. |
| LinearSGD(LossFunc loss, double lambda0, double lambda1) | Creates a new LinearSGD learner. |

| Modifier and Type | Method | Description |
|---|---|---|
| CategoricalResults | classify(DataPoint data) | Performs classification on the given data point. |
| LinearSGD | clone() | |
| double | getBias(int index) | Returns the bias term used with the weight vector for the given class index. |
| double | getEta() | Returns the current learning rate in use. |
| DecayRate | getEtaDecay() | Returns the decay rate in use. |
| GradientUpdater | getGradientUpdater() | |
| double | getLambda0() | Returns the L2 regularization term in use. |
| double | getLambda1() | Returns the L1 regularization term in use. |
| LossFunc | getLoss() | Returns the loss function in use. |
| Parameter | getParameter(String paramName) | Returns the parameter with the given name. |
| List<Parameter> | getParameters() | Returns the list of parameters that can be altered for this learner. |
| Vec | getRawWeight(int index) | Returns the raw weight vector associated with the given class index. |
| static Distribution | guessLambda0(DataSet d) | Guess the distribution to use for the regularization term λ0. |
| static Distribution | guessLambda1(DataSet d) | Guess the distribution to use for the regularization term λ1. |
| boolean | isUseBias() | Returns whether or not an implicit bias term is in use. |
| int | numWeightsVecs() | Returns the number of weight vectors that can be returned. |
| double | regress(DataPoint data) | |
| void | setEta(double eta) | Sets the initial learning rate η to use. |
| void | setEtaDecay(DecayRate decay) | Sets the rate at which η is decayed at each update. |
| void | setGradientUpdater(GradientUpdater gradientUpdater) | Sets the method that will be used to update the weight vectors given their gradient information. |
| void | setLambda0(double lambda0) | λ0 controls the L2 regularization penalty. |
| void | setLambda1(double lambda1) | λ1 controls the L1 regularization penalty. |
| void | setLoss(LossFunc loss) | Sets the loss function used for the model. |
| void | setUp(CategoricalData[] categoricalAttributes, int numericAttributes) | Prepares the classifier to begin learning from its UpdateableRegressor.update(jsat.classifiers.DataPoint, double) method. |
| void | setUp(CategoricalData[] categoricalAttributes, int numericAttributes, CategoricalData predicting) | Prepares the classifier to begin learning from its UpdateableClassifier.update(jsat.classifiers.DataPoint, int) method. |
| void | setUseBias(boolean useBias) | Sets whether or not an implicit bias term will be added to the data set. |
| boolean | supportsWeightedData() | Indicates whether the model knows how to train using weighted data points. |
| void | train(RegressionDataSet dataSet) | |
| void | train(RegressionDataSet dataSet, ExecutorService threadPool) | |
| void | update(DataPoint dataPoint, double targetValue) | Updates the classifier by giving it a new data point to learn from. |
| void | update(DataPoint dataPoint, int targetClass) | Updates the classifier by giving it a new data point to learn from. |

Methods inherited from class BaseUpdateableClassifier: getEpochs, setEpochs, trainC (two overloads), trainEpochs

public LinearSGD()

public LinearSGD(LossFunc loss, double lambda0, double lambda1)
Parameters:
loss - the loss function to use
lambda0 - the L2 regularization term
lambda1 - the L1 regularization term

public LinearSGD(LossFunc loss, double eta, DecayRate decay, double lambda0, double lambda1)
Parameters:
loss - the loss function to use
eta - the initial learning rate
decay - the decay rate for η
lambda0 - the L2 regularization term
lambda1 - the L1 regularization term

public LinearSGD(LinearSGD toClone)
Parameters:
toClone - the object to copy

public void setGradientUpdater(GradientUpdater gradientUpdater)
Parameters:
gradientUpdater - the method to use for updating the weight vectors from the gradient

public GradientUpdater getGradientUpdater()
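Conceptually, the gradient updater decides how a gradient turns into a change of the weights. The sketch below uses a hypothetical `WeightUpdater` interface, which is not the library's GradientUpdater API. It implements two textbook rules. The first is plain SGD, w_j ← w_j − η g_j. The second is AdaGrad, which divides each coordinate's step by the square root of that coordinate's accumulated squared gradients. This page does not state which updaters the library provides or which one LinearSGD uses by default.

```java
/** Hypothetical stand-in for a pluggable update rule; not the library's GradientUpdater API. */
interface WeightUpdater {
    /** Applies one step in place: w is modified using gradient g and learning rate eta. */
    void step(double[] w, double[] g, double eta);
}

/** Plain SGD: w_j <- w_j - eta * g_j. Stateless. */
final class PlainSgdUpdater implements WeightUpdater {
    @Override
    public void step(double[] w, double[] g, double eta) {
        for (int j = 0; j < w.length; j++) {
            w[j] -= eta * g[j];
        }
    }
}

/** AdaGrad: per-coordinate step sizes shrink as squared gradients accumulate. */
final class AdaGradUpdater implements WeightUpdater {
    private final double[] sumSq; // running sum of g_j^2 for each coordinate
    private final double eps;     // guards against division by zero

    AdaGradUpdater(int dim, double eps) {
        this.sumSq = new double[dim];
        this.eps = eps;
    }

    @Override
    public void step(double[] w, double[] g, double eta) {
        for (int j = 0; j < w.length; j++) {
            sumSq[j] += g[j] * g[j];
            w[j] -= eta * g[j] / (Math.sqrt(sumSq[j]) + eps);
        }
    }
}
```

The contrast between the two rules is the point of the sketch. An updater may carry state across calls, as AdaGrad does, so it is natural to supply it as a separate object rather than hard-coding one rule into the learner.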

public void setEtaDecay(DecayRate decay)
Sets the rate at which η is decayed at each update.
Parameters:
decay - the decay rate to use

public DecayRate getEtaDecay()
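One way to picture a decay rate is as a function from the initial rate η and the update count t to the rate actually used at that update. The sketch below assumes a hypothetical `DecaySchedule` interface, which is not the library's DecayRate API. It shows two common choices: no decay, and inverse decay η_t = η / (1 + αt), where α is an assumed tuning constant.

```java
/** Hypothetical decay schedule: the learning rate used at update t, given the initial rate eta0. */
interface DecaySchedule {
    double rate(double eta0, long t);
}

/** No decay: every update uses eta0. */
final class ConstantSchedule implements DecaySchedule {
    @Override
    public double rate(double eta0, long t) {
        return eta0;
    }
}

/** Inverse decay: eta_t = eta0 / (1 + alpha * t). */
final class InverseSchedule implements DecaySchedule {
    private final double alpha;

    InverseSchedule(double alpha) {
        if (alpha < 0) {
            throw new IllegalArgumentException("alpha must be non-negative");
        }
        this.alpha = alpha;
    }

    @Override
    public double rate(double eta0, long t) {
        return eta0 / (1.0 + alpha * t);
    }
}
```

With `new InverseSchedule(0.5)` and η = 0.1, the first update uses 0.1 and the third update (t = 2) uses 0.05.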

public void setEta(double eta)
Parameters:
eta - the learning rate to use

public double getEta()

public void setLoss(LossFunc loss)
Parameters:
loss - the loss function to use

public LossFunc getLoss()
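The loss is what lets the same SGD machinery solve either regression or classification. The sketch below defines a hypothetical `ScoreLoss` interface, which is not the library's LossFunc API. It works on the score s = w·x and the target y. Two standard losses are implemented together with their derivatives with respect to s:

* squared loss, ½(s − y)², for regression;
* logistic loss, log(1 + e^(−ys)), for binary labels y ∈ {−1, +1}.

The class names are illustrative only. In an SGD step, the gradient with respect to w is this derivative times x.

```java
/** Hypothetical loss on the score s = w·x and target y; not the library's LossFunc API. */
interface ScoreLoss {
    double value(double score, double target);

    /** Derivative of value(...) with respect to the score. */
    double derivative(double score, double target);
}

/** Regression: 0.5 * (s - y)^2, with derivative (s - y). */
final class SquaredLossSketch implements ScoreLoss {
    @Override
    public double value(double s, double y) {
        double r = s - y;
        return 0.5 * r * r;
    }

    @Override
    public double derivative(double s, double y) {
        return s - y;
    }
}

/** Binary classification with y in {-1, +1}: log(1 + exp(-y * s)). */
final class LogisticLossSketch implements ScoreLoss {
    @Override
    public double value(double s, double y) {
        double m = y * s; // margin
        // Numerically stable form of log(1 + exp(-m)) for both signs of m.
        return m > 0 ? Math.log1p(Math.exp(-m)) : -m + Math.log1p(Math.exp(m));
    }

    @Override
    public double derivative(double s, double y) {
        // d/ds log(1 + exp(-y s)) = -y / (1 + exp(y s))
        return -y / (1.0 + Math.exp(y * s));
    }
}
```

The branch in `LogisticLossSketch.value` avoids overflowing `Math.exp` for large negative margins. Both branches compute the same quantity.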

public void setLambda0(double lambda0)
Parameters:
lambda0 - the L2 regularization penalty to use

public double getLambda0()
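The gradient of (λ0/2)||w||₂² is λ0·w. An SGD step on a single point's loss gradient g plus the L2 term therefore shrinks every weight by the factor (1 − ηλ0) and then subtracts ηg. The sketch below is a hypothetical helper that shows this step on a dense array. It is not LinearSGD's internal code.

```java
/** Hypothetical illustration of one SGD step with an L2 penalty; not library code. */
final class L2StepSketch {
    private L2StepSketch() {}

    /**
     * One step on loss gradient g plus (lambda0 / 2) * ||w||^2, whose gradient is lambda0 * w:
     * w <- (1 - eta * lambda0) * w - eta * g, with g computed at the old w.
     */
    static void step(double[] w, double[] g, double eta, double lambda0) {
        double shrink = 1.0 - eta * lambda0;
        for (int j = 0; j < w.length; j++) {
            w[j] = shrink * w[j] - eta * g[j];
        }
    }
}
```

This sketch touches every coordinate on every update, which is wasteful when inputs are sparse. A common alternative is to store w as a scalar times a vector and fold the shrink factor into the scalar. This page does not say whether LinearSGD does so.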

public void setLambda1(double lambda1)
Parameters:
lambda1 - the L1 regularization penalty to use

public double getLambda1()
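A plain subgradient step on λ1||w||₁ rarely lands a weight exactly on zero. One standard remedy is the proximal (soft-thresholding) step. It first takes the ordinary gradient step on the loss, then moves each weight toward zero by ηλ1, stopping at zero. The sketch below shows this as a swapped-in technique, not as a description of how LinearSGD itself handles the L1 term. The class name `L1StepSketch` is hypothetical.

```java
/** Hypothetical illustration of an L1 proximal (soft-thresholding) step; not library code. */
final class L1StepSketch {
    private L1StepSketch() {}

    /** Moves v toward zero by tau, clipping at zero instead of crossing it. */
    static double softThreshold(double v, double tau) {
        if (v > tau) {
            return v - tau;
        }
        if (v < -tau) {
            return v + tau;
        }
        return 0.0;
    }

    /** Gradient step on the loss, then the proximal step for eta * lambda1 * ||w||_1. */
    static void step(double[] w, double[] g, double eta, double lambda1) {
        double tau = eta * lambda1;
        for (int j = 0; j < w.length; j++) {
            w[j] = softThreshold(w[j] - eta * g[j], tau);
        }
    }
}
```

Weights whose magnitude falls below ηλ1 after the gradient step become exactly zero. This is how L1 regularization yields sparse weight vectors.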

public void setUseBias(boolean useBias)
Parameters:
useBias - true to add an implicit bias term

public boolean isUseBias()
Returns:
true if a bias term is in use

public LinearSGD clone()
Specified by:
clone in interface Classifier
clone in interface UpdateableClassifier
clone in interface Regressor
clone in interface UpdateableRegressor
Overrides:
clone in class BaseUpdateableClassifier

public void setUp(CategoricalData[] categoricalAttributes, int numericAttributes, CategoricalData predicting)
Prepares the classifier to begin learning from its UpdateableClassifier.update(jsat.classifiers.DataPoint, int) method.
Specified by:
setUp in interface UpdateableClassifier
Parameters:
categoricalAttributes - an array containing the categorical attributes that will be in each data point
numericAttributes - the number of numeric attributes that will be in each data point
predicting - the information for the target class that will be predicted

public void setUp(CategoricalData[] categoricalAttributes, int numericAttributes)
Prepares the classifier to begin learning from its UpdateableRegressor.update(jsat.classifiers.DataPoint, double) method.
Specified by:
setUp in interface UpdateableRegressor
Parameters:
categoricalAttributes - an array containing the categorical attributes that will be in each data point
numericAttributes - the number of numeric attributes that will be in each data point

public void update(DataPoint dataPoint, int targetClass)
Updates the classifier by giving it a new data point to learn from.
Specified by:
update in interface UpdateableClassifier
Parameters:
dataPoint - the data point to learn
targetClass - the target class of the data point

public void update(DataPoint dataPoint, double targetValue)
Updates the classifier by giving it a new data point to learn from.
Specified by:
update in interface UpdateableRegressor
Parameters:
dataPoint - the data point to learn
targetValue - the target value of the data point

public CategoricalResults classify(DataPoint data)
Performs classification on the given data point.
Specified by:
classify in interface Classifier
Parameters:
data - the data point to classify

public boolean supportsWeightedData()
Indicates whether the model knows how to train using weighted data points.
Specified by:
supportsWeightedData in interface Classifier
supportsWeightedData in interface Regressor

public void train(RegressionDataSet dataSet, ExecutorService threadPool)

public void train(RegressionDataSet dataSet)
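For an updateable learner, training on a whole data set can be pictured as repeated passes of single-point updates. The inherited getEpochs and setEpochs methods suggest that the number of passes is configurable, but this page does not document the exact procedure. The sketch below is a hypothetical, self-contained online linear regressor, not library code. It uses squared loss, a constant learning rate and an optional implicit bias term. Its `update(x, y)` and `regress(x)` mirror the roles of update(DataPoint, double) and regress(DataPoint).

```java
/** Hypothetical online linear regressor illustrating update/regress/train; not library code. */
final class OnlineRegressorSketch {
    private final double[] w;
    private final boolean useBias;
    private final double eta;
    private double bias = 0.0;

    OnlineRegressorSketch(int dim, double eta, boolean useBias) {
        this.w = new double[dim];
        this.eta = eta;
        this.useBias = useBias;
    }

    /** Prediction: w·x, plus the bias term when one is in use. */
    double regress(double[] x) {
        double s = useBias ? bias : 0.0;
        for (int j = 0; j < w.length; j++) {
            s += w[j] * x[j];
        }
        return s;
    }

    /** One SGD step on the squared loss 0.5 * (w·x + b - y)^2 for a single point. */
    void update(double[] x, double y) {
        double err = regress(x) - y; // derivative of the loss with respect to the score
        for (int j = 0; j < w.length; j++) {
            w[j] -= eta * err * x[j];
        }
        if (useBias) {
            bias -= eta * err; // the implicit bias acts like a feature fixed at 1
        }
    }

    /** Batch training as repeated passes (epochs) of online updates. */
    void train(double[][] xs, double[] ys, int epochs) {
        for (int e = 0; e < epochs; e++) {
            for (int i = 0; i < xs.length; i++) {
                update(xs[i], ys[i]);
            }
        }
    }
}
```

Training this sketch on the points (0, 1), (1, 3) and (2, 5) with the bias enabled recovers y ≈ 2x + 1. With the bias disabled, the model is forced through the origin. This is the difference that setUseBias(boolean) controls.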

public List<Parameter> getParameters()
Returns the list of parameters that can be altered for this learner.
Specified by:
getParameters in interface Parameterized

public Parameter getParameter(String paramName)
Returns the parameter with the given name.
Specified by:
getParameter in interface Parameterized
Parameters:
paramName - the name of the parameter to obtain

public Vec getRawWeight(int index)
Returns the raw weight vector associated with the given class index. … a ConstantVector object may be returned. … index = 0 should be used.
Specified by:
getRawWeight in interface SimpleWeightVectorModel
Parameters:
index - the class index to get the weight vector for

public double getBias(int index)
Returns the bias term used with the weight vector for the given class index. … 0 will be returned. … index = 0 should be used.
Specified by:
getBias in interface SimpleWeightVectorModel
Parameters:
index - the class index to get the weight vector for

public int numWeightsVecs()
Returns the number of weight vectors that can be returned. … SimpleWeightVectorModel.getRawWeight(int) can be called.
Specified by:
numWeightsVecs in interface SimpleWeightVectorModel

public static Distribution guessLambda0(DataSet d)
Guess the distribution to use for the regularization term λ0.
Parameters:
d - the data set to get the guess for

public static Distribution guessLambda1(DataSet d)
Guess the distribution to use for the regularization term λ1.
Parameters:
d - the data set to get the guess for

Copyright © 2017. All rights reserved.