public class KernelSGD extends Object implements UpdateableClassifier, UpdateableRegressor, Parameterized
Kernel SGD is the kernelized counterpart to LinearSGD, and learns nonlinear functions via the kernel trick. The implementation is built upon KernelPoint and KernelPoints to support budgeted learning.

Following the LinearSGD implementation, whether this algorithm performs regression, binary classification, or multi-class classification is controlled by the loss function used.
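
To make the ideas of the kernel trick and a support vector budget concrete, here is a minimal standalone sketch of budgeted kernel SGD with hinge loss. It is not JSAT's implementation: the class name, the RBF kernel parameter gamma, and the "drop the oldest support vector" budget rule are all assumptions chosen for illustration. The only detail taken from this page is the step size η / (λ * (t + 2 / λ)).

```java
import java.util.ArrayList;
import java.util.List;

/** Illustrative budgeted kernel SGD with hinge loss; not JSAT's implementation. */
final class BudgetedKernelSgdSketch {
    private final List<double[]> vectors = new ArrayList<>(); // support vectors
    private final List<Double> alphas = new ArrayList<>();    // their coefficients
    private final double eta, lambda, gamma;
    private final int budget;
    private long t = 0; // time step

    BudgetedKernelSgdSketch(double eta, double lambda, double gamma, int budget) {
        this.eta = eta; this.lambda = lambda; this.gamma = gamma; this.budget = budget;
    }

    /** RBF kernel k(a, b) = exp(-gamma * ||a - b||^2). */
    private double kernel(double[] a, double[] b) {
        double sq = 0;
        for (int i = 0; i < a.length; i++) {
            double d = a[i] - b[i];
            sq += d * d;
        }
        return Math.exp(-gamma * sq);
    }

    /** Decision value f(x) = sum_i alpha_i * k(x_i, x). */
    double score(double[] x) {
        double f = 0;
        for (int i = 0; i < vectors.size(); i++)
            f += alphas.get(i) * kernel(vectors.get(i), x);
        return f;
    }

    int predict(double[] x) { return score(x) >= 0 ? 1 : -1; }

    int size() { return vectors.size(); }

    /** One SGD step on a labelled point, y in {-1, +1}. */
    void update(double[] x, int y) {
        double rate = eta / (lambda * (t + 2.0 / lambda));
        double margin = y * score(x);
        // L2 regularization shrinks every existing coefficient.
        for (int i = 0; i < alphas.size(); i++)
            alphas.set(i, alphas.get(i) * (1.0 - rate * lambda));
        if (margin < 1.0) { // hinge loss is active: add x as a support vector
            if (vectors.size() == budget) { // budget full: drop the oldest (assumed rule)
                vectors.remove(0);
                alphas.remove(0);
            }
            vectors.add(x.clone());
            alphas.add(rate * y);
        }
        t++;
    }

    public static void main(String[] args) {
        double[][] xs = {{0, 0}, {1, 1}, {0, 1}, {1, 0}}; // XOR: not linearly separable
        int[] ys = {-1, -1, 1, 1};
        BudgetedKernelSgdSketch model = new BudgetedKernelSgdSketch(1.0, 0.01, 2.0, 8);
        for (int epoch = 0; epoch < 50; epoch++)
            for (int i = 0; i < xs.length; i++)
                model.update(xs[i], ys[i]);
        for (int i = 0; i < xs.length; i++)
            System.out.println(ys[i] + " -> " + model.predict(xs[i]));
        System.out.println("support vectors: " + model.size());
    }
}
```

The XOR data cannot be separated by a linear model, which is why it is used here. The budget keeps the number of stored vectors fixed no matter how many updates are made.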

The learning rate used is η / (λ * (t + 2 / λ)), where η is the base learning rate, λ is the L2 regularization parameter, and t is the time step.

Constructor | Description |
---|---|
KernelSGD() | Creates a new Kernel SGD object for classification with the RBF kernel. |
KernelSGD(KernelSGD toCopy) | Copy constructor. |
KernelSGD(LossFunc loss, KernelTrick kernel, double lambda, KernelPoint.BudgetStrategy budgetStrategy, int budgetSize) | Creates a new Kernel SGD object. |
KernelSGD(LossFunc loss, KernelTrick kernel, double lambda, KernelPoint.BudgetStrategy budgetStrategy, int budgetSize, double eta, double errorTolerance) | Creates a new Kernel SGD object. |
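
Reading the expression η / (λ * (t + 2 / λ)) from the class description as the step size at time step t, with η the base learning rate and λ the L2 regularization parameter, the schedule can be tabulated directly. Since λ * (t + 2 / λ) = λt + 2, it starts at η / 2 for t = 0 and decays roughly like η / (λt) once λt is large. The sketch below is a standalone illustration; the class name, method name, and parameter values are made up here and are not part of JSAT.

```java
/** Tabulates the step size eta / (lambda * (t + 2 / lambda)); illustrative only. */
final class LearningRateSketch {
    /** Step size at time step t for base rate eta and regularization lambda. */
    static double learningRate(double eta, double lambda, long t) {
        return eta / (lambda * (t + 2.0 / lambda));
    }

    public static void main(String[] args) {
        double eta = 1.0;     // base learning rate (illustrative value)
        double lambda = 1e-4; // L2 regularization (illustrative value)
        for (long t : new long[]{0, 1, 10, 100, 1000, 100000}) {
            System.out.printf("t=%d rate=%.6f%n", t, learningRate(eta, lambda, t));
        }
    }
}
```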
Modifier and Type | Method | Description |
---|---|---|
CategoricalResults | classify(DataPoint data) | Performs classification on the given data point. |
KernelSGD | clone() | |
int | getBudgetSize() | Returns the budget size, or maximum number of allowed support vectors. |
KernelPoint.BudgetStrategy | getBudgetStrategy() | Returns the method of budget maintenance. |
int | getEpochs() | Returns the number of epochs to use. |
double | getErrorTolerance() | Returns the error tolerance that would be used. |
double | getEta() | Returns the base learning rate. |
KernelTrick | getKernel() | Returns the kernel in use. |
double | getLambda() | Returns the L2 regularization parameter. |
LossFunc | getLoss() | Returns the loss function in use. |
Parameter | getParameter(String paramName) | Returns the parameter with the given name. |
List<Parameter> | getParameters() | Returns the list of parameters that can be altered for this learner. |
static Distribution | guessLambda(DataSet d) | Guess the distribution to use for the regularization term λ. |
double | regress(DataPoint data) | |
void | setBudgetSize(int budgetSize) | Sets the maximum budget size, or number of support vectors, to allow during training. |
void | setBudgetStrategy(KernelPoint.BudgetStrategy budgetStrategy) | Sets the budget maintenance strategy. |
void | setEpochs(int epochs) | Sets the number of iterations over the training set performed during batch training. |
void | setErrorTolerance(double errorTolerance) | Sets the error tolerance used for certain budget strategies. |
void | setEta(double eta) | Sets the base learning rate to start from. |
void | setKernel(KernelTrick kernel) | Sets the kernel to use. |
void | setLambda(double lambda) | Sets the L2 regularization parameter used during learning. |
void | setLoss(LossFunc loss) | Sets the loss function to use. |
void | setUp(CategoricalData[] categoricalAttributes, int numericAttributes) | Prepares the classifier to begin learning from its UpdateableRegressor.update(jsat.classifiers.DataPoint, double) method. |
void | setUp(CategoricalData[] categoricalAttributes, int numericAttributes, CategoricalData predicting) | Prepares the classifier to begin learning from its UpdateableClassifier.update(jsat.classifiers.DataPoint, int) method. |
boolean | supportsWeightedData() | Indicates whether the model knows how to train using weighted data points. |
void | train(RegressionDataSet dataSet) | |
void | train(RegressionDataSet dataSet, ExecutorService threadPool) | |
void | trainC(ClassificationDataSet dataSet) | Trains the classifier and constructs a model for classification using the given data set. |
void | trainC(ClassificationDataSet dataSet, ExecutorService threadPool) | Trains the classifier and constructs a model for classification using the given data set. |
void | update(DataPoint dataPoint, double targetValue) | Updates the classifier by giving it a new data point to learn from. |
void | update(DataPoint dataPoint, int targetClass) | Updates the classifier by giving it a new data point to learn from. |
public KernelSGD()

public KernelSGD(LossFunc loss, KernelTrick kernel, double lambda, KernelPoint.BudgetStrategy budgetStrategy, int budgetSize)
loss - the loss function to use
kernel - the kernel trick to use
lambda - the regularization penalty
budgetStrategy - the budget maintenance strategy to use
budgetSize - the maximum support vector budget

public KernelSGD(LossFunc loss, KernelTrick kernel, double lambda, KernelPoint.BudgetStrategy budgetStrategy, int budgetSize, double eta, double errorTolerance)
loss - the loss function to use
kernel - the kernel trick to use
lambda - the regularization penalty
eta - the initial learning rate
budgetStrategy - the budget maintenance strategy to use
errorTolerance - the error tolerance used in certain budget maintenance steps
budgetSize - the maximum support vector budget

public KernelSGD(KernelSGD toCopy)
toCopy - the object to copy

public void setEpochs(int epochs)
epochs - the number of iterations in batch training

public int getEpochs()
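
The epochs setting is described only as the number of iterations over the training set in batch training. One plausible reading, sketched below, is that batch training feeds every data point to the online update once per epoch. The helper name, the generic Consumer callback, and the per-epoch reshuffle are assumptions made for illustration; this page does not say whether JSAT shuffles the data.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Random;
import java.util.function.Consumer;

/** Illustrative epoch loop around an online update; not JSAT code. */
final class EpochLoopSketch {
    /** Feeds every item to the update once per epoch and returns the number of calls. */
    static <T> int runEpochs(List<T> data, int epochs, long seed, Consumer<T> update) {
        List<T> order = new ArrayList<>(data); // leave the caller's list untouched
        Random rng = new Random(seed);
        int calls = 0;
        for (int epoch = 0; epoch < epochs; epoch++) {
            Collections.shuffle(order, rng); // assumption: reshuffle before each pass
            for (T item : order) {
                update.accept(item);
                calls++;
            }
        }
        return calls;
    }

    public static void main(String[] args) {
        List<Integer> data = Arrays.asList(1, 2, 3, 4);
        int[] sum = {0};
        int calls = runEpochs(data, 5, 42L, v -> sum[0] += v);
        System.out.println(calls + " updates, sum = " + sum[0]);
    }
}
```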

public void setLoss(LossFunc loss)
loss - the loss function to use

public LossFunc getLoss()
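
Because the loss function decides whether the learner acts as a classifier or a regressor, it may help to see how two common losses would drive a single kernel SGD step. The sketch below uses a hypothetical Loss interface defined locally (it is not JSAT's LossFunc) and computes the coefficient a gradient step would assign to a newly added support vector, namely the negative step size times the derivative of the loss with respect to the prediction.

```java
/** Illustrative loss derivatives for a kernel SGD step; not JSAT's LossFunc. */
final class LossSketch {
    /** Hypothetical loss interface: derivative of the loss with respect to the prediction. */
    interface Loss {
        double deriv(double pred, double y);
    }

    /** Hinge loss max(0, 1 - y * pred), for labels y in {-1, +1}. */
    static final Loss HINGE = (pred, y) -> (y * pred < 1.0) ? -y : 0.0;

    /** Squared loss 0.5 * (pred - y)^2, for real-valued targets. */
    static final Loss SQUARED = (pred, y) -> pred - y;

    /** Coefficient an SGD step would give a new support vector: -rate * dLoss/dPred. */
    static double newCoefficient(Loss loss, double rate, double pred, double y) {
        return -rate * loss.deriv(pred, y);
    }

    public static void main(String[] args) {
        System.out.println(newCoefficient(HINGE, 0.5, 0.2, 1.0));   // margin violated
        System.out.println(newCoefficient(HINGE, 0.5, 2.0, 1.0));   // margin satisfied
        System.out.println(newCoefficient(SQUARED, 0.5, 3.0, 1.0)); // regression residual
    }
}
```

With hinge loss a point that already satisfies the margin contributes nothing, while squared loss always pushes the prediction toward the target, which is the practical difference between the classification and regression settings.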

public void setLambda(double lambda)
lambda - the positive regularization parameter

public double getLambda()

public void setErrorTolerance(double errorTolerance)
Sets the error tolerance used for certain budget strategies.
errorTolerance - the error tolerance in [0, 1]

public double getErrorTolerance()

public void setBudgetSize(int budgetSize)
budgetSize - the maximum allowed number of support vectors

public int getBudgetSize()

public void setBudgetStrategy(KernelPoint.BudgetStrategy budgetStrategy)
budgetStrategy - the method to meet budget size requirements

public KernelPoint.BudgetStrategy getBudgetStrategy()

public void setEta(double eta)
eta - the starting learning rate to use

public double getEta()

public void setKernel(KernelTrick kernel)
kernel - the kernel to use

public KernelTrick getKernel()

public void setUp(CategoricalData[] categoricalAttributes, int numericAttributes, CategoricalData predicting)
Prepares the classifier to begin learning from its UpdateableClassifier.update(jsat.classifiers.DataPoint, int) method.
Specified by: setUp in interface UpdateableClassifier
categoricalAttributes - an array containing the categorical attributes that will be in each data point
numericAttributes - the number of numeric attributes that will be in each data point
predicting - the information for the target class that will be predicted

public void setUp(CategoricalData[] categoricalAttributes, int numericAttributes)
Prepares the classifier to begin learning from its UpdateableRegressor.update(jsat.classifiers.DataPoint, double) method.
Specified by: setUp in interface UpdateableRegressor
categoricalAttributes - an array containing the categorical attributes that will be in each data point
numericAttributes - the number of numeric attributes that will be in each data point

public void update(DataPoint dataPoint, int targetClass)
Updates the classifier by giving it a new data point to learn from.
Specified by: update in interface UpdateableClassifier
dataPoint - the data point to learn
targetClass - the target class of the data point

public void update(DataPoint dataPoint, double targetValue)
Updates the classifier by giving it a new data point to learn from.
Specified by: update in interface UpdateableRegressor
dataPoint - the data point to learn
targetValue - the target value of the data point

public CategoricalResults classify(DataPoint data)
Performs classification on the given data point.
Specified by: classify in interface Classifier
data - the data point to classify

public void trainC(ClassificationDataSet dataSet, ExecutorService threadPool)
Trains the classifier and constructs a model for classification using the given data set.
Specified by: trainC in interface Classifier
dataSet - the data set to train on
threadPool - the source of threads to use

public void trainC(ClassificationDataSet dataSet)
Trains the classifier and constructs a model for classification using the given data set.
Specified by: trainC in interface Classifier
dataSet - the data set to train on

public boolean supportsWeightedData()
Indicates whether the model knows how to train using weighted data points.
Specified by: supportsWeightedData in interface Classifier
Specified by: supportsWeightedData in interface Regressor

public void train(RegressionDataSet dataSet, ExecutorService threadPool)

public void train(RegressionDataSet dataSet)

public KernelSGD clone()
Specified by: clone in interface Classifier
Specified by: clone in interface UpdateableClassifier
Specified by: clone in interface Regressor
Specified by: clone in interface UpdateableRegressor
Overrides: clone in class Object

public List<Parameter> getParameters()
Returns the list of parameters that can be altered for this learner.
Specified by: getParameters in interface Parameterized

public Parameter getParameter(String paramName)
Returns the parameter with the given name.
Specified by: getParameter in interface Parameterized
paramName - the name of the parameter to obtain

public static Distribution guessLambda(DataSet d)
Guess the distribution to use for the regularization term λ.
d - the data set to get the guess for

Copyright © 2017. All rights reserved.