public class LinearBatch extends Object implements Classifier, Regressor, Parameterized, SimpleWeightVectorModel, WarmClassifier, WarmRegressor
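LinearBatch learns either a classification or a regression problem, depending on the loss function $\ell(w, x)$ used. The solution attempts to minimize $\sum_i \ell(w, x_i) + \frac{\lambda_0}{2} \|w\|_2^2$, and is trained using a batch optimization method.

LinearBatch can warm start its training from any object that implements the SimpleWeightVectorModel interface.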
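To make the objective concrete, here is a minimal Java sketch that evaluates $\sum_i \ell(w, x_i) + \frac{\lambda_0}{2} \|w\|_2^2$ directly. It assumes logistic loss $\ell = \log(1 + e^{-y\, w \cdot x})$, with the label $y \in \{-1, +1\}$ folded into $\ell$, and uses plain `double[]` arrays. The class `RegularizedObjective` is hypothetical. It illustrates the formula and does not show how LinearBatch computes its objective internally.

```java
/**
 * Hypothetical sketch: evaluates sum_i loss(w, x_i, y_i) + (lambda0 / 2) * ||w||_2^2.
 * Assumption: logistic loss with labels y in {-1, +1}; this is not JSAT's code.
 */
public class RegularizedObjective {

    // Dot product w . x
    public static double dot(double[] w, double[] x) {
        double s = 0.0;
        for (int j = 0; j < w.length; j++) {
            s += w[j] * x[j];
        }
        return s;
    }

    // Logistic loss log(1 + exp(-y * w.x)), chosen here only for illustration
    public static double logisticLoss(double[] w, double[] x, double y) {
        return Math.log1p(Math.exp(-y * dot(w, x)));
    }

    // Summed loss plus the L2 penalty (lambda0 / 2) * ||w||^2
    public static double objective(double[] w, double[][] xs, double[] ys, double lambda0) {
        double total = 0.0;
        for (int i = 0; i < xs.length; i++) {
            total += logisticLoss(w, xs[i], ys[i]);
        }
        return total + 0.5 * lambda0 * dot(w, w);
    }

    public static void main(String[] args) {
        double[][] xs = {{1.0, 2.0}, {-1.0, -0.5}};
        double[] ys = {1.0, -1.0};
        // At w = 0 each logistic loss equals log(2) and the penalty is 0
        System.out.println(objective(new double[]{0.0, 0.0}, xs, ys, 1e-3));
    }
}
```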
Modifier and Type | Class and Description |
---|---|
class | LinearBatch.GradFunction |
class | LinearBatch.LossFunction |
class | LinearBatch.LossMCFunction |
Constructor and Description |
---|
LinearBatch(): Creates a new Linear Batch learner for classification using a small regularization term. |
LinearBatch(LinearBatch toCopy): Copy constructor. |
LinearBatch(LossFunc loss, double lambda0): Creates a new Linear Batch learner. |
LinearBatch(LossFunc loss, double lambda0, double tolerance): Creates a new Linear Batch learner. |
LinearBatch(LossFunc loss, double lambda0, double tolerance, Optimizer2 optimizer): Creates a new Linear Batch learner. |
Modifier and Type | Method and Description |
---|---|
CategoricalResults | classify(DataPoint data): Performs classification on the given data point. |
LinearBatch | clone() |
double | getBias(int index): Returns the bias term used with the weight vector for the given class index. |
double | getLambda0(): Returns the L2 regularization term in use. |
LossFunc | getLoss(): Returns the loss function in use. |
Optimizer2 | getOptimizer(): Returns the optimization method in use, or null. |
Parameter | getParameter(String paramName): Returns the parameter with the given name. |
List<Parameter> | getParameters(): Returns the list of parameters that can be altered for this learner. |
Vec | getRawWeight(int index): Returns the raw weight vector associated with the given class index. |
double | getTolerance(): Returns the value of the convergence tolerance parameter. |
static Distribution | guessLambda0(DataSet d): Guesses the distribution to use for the regularization term λ0. |
boolean | isUseBiasTerm() |
int | numWeightsVecs(): Returns the number of weight vectors that can be returned. |
double | regress(DataPoint data) |
void | setLambda0(double lambda0): λ0 controls the L2 regularization penalty. |
void | setLoss(LossFunc loss): Sets the loss function used for the model. |
void | setOptimizer(Optimizer2 optimizer): Sets the method of batch optimization that will be used. |
void | setTolerance(double tolerance): Sets the convergence tolerance to use for training. |
void | setUseBiasTerm(boolean useBiasTerm) |
boolean | supportsWeightedData(): Indicates whether the model knows how to train using weighted data points. |
void | train(RegressionDataSet dataSet) |
void | train(RegressionDataSet D, ExecutorService threadPool) |
void | train(RegressionDataSet dataSet, Regressor warmSolution): Trains the regressor and constructs a model for regression using the given data set. |
void | train(RegressionDataSet D, Regressor warmSolution, ExecutorService threadPool): Trains the regressor and constructs a model for regression using the given data set. |
void | trainC(ClassificationDataSet dataSet): Trains the classifier and constructs a model for classification using the given data set. |
void | trainC(ClassificationDataSet dataSet, Classifier warmSolution): Trains the classifier and constructs a model for classification using the given data set. |
void | trainC(ClassificationDataSet D, Classifier warmSolution, ExecutorService threadPool): Trains the classifier and constructs a model for classification using the given data set. |
void | trainC(ClassificationDataSet D, ExecutorService threadPool): Trains the classifier and constructs a model for classification using the given data set. |
boolean | warmFromSameDataOnly(): Some models can only be warm started from a solution trained on the exact same data set as the model being warm started. |
public LinearBatch()
Creates a new Linear Batch learner for classification using a small regularization term.

public LinearBatch(LossFunc loss, double lambda0)
Creates a new Linear Batch learner.
Parameters:
loss - the loss function to use
lambda0 - the L2 regularization term

public LinearBatch(LossFunc loss, double lambda0, double tolerance)
Creates a new Linear Batch learner.
Parameters:
loss - the loss function to use
lambda0 - the L2 regularization term
tolerance - the threshold for convergence

public LinearBatch(LossFunc loss, double lambda0, double tolerance, Optimizer2 optimizer)
Creates a new Linear Batch learner.
Parameters:
loss - the loss function to use
lambda0 - the L2 regularization term
tolerance - the threshold for convergence
optimizer - the batch optimization method to use

public LinearBatch(LinearBatch toCopy)
Copy constructor.
Parameters:
toCopy - the object to copy

public void setUseBiasTerm(boolean useBiasTerm)

public boolean isUseBiasTerm()
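This page does not describe what the bias term setting does. Under the conventional reading, which is an assumption here, enabling it adds a learned intercept b to the linear score w · x + b, and disabling it holds b at 0 so that the decision boundary w · x = 0 passes through the origin. The hypothetical class `BiasSketch` below illustrates only that convention; it is not LinearBatch's implementation.

```java
/**
 * Hypothetical sketch of a linear score with an optional bias term:
 * score(x) = w . x + b, where b is forced to 0 when the bias is disabled.
 * Names are illustrative only and are not JSAT's API.
 */
public class BiasSketch {
    private final double[] w;
    private final double b;
    private final boolean useBiasTerm;

    public BiasSketch(double[] w, double b, boolean useBiasTerm) {
        this.w = w.clone();
        // A model trained without a bias term behaves as if b = 0
        this.b = useBiasTerm ? b : 0.0;
        this.useBiasTerm = useBiasTerm;
    }

    public boolean isUseBiasTerm() {
        return useBiasTerm;
    }

    public double getBias() {
        return b;
    }

    // Linear score w . x + b
    public double score(double[] x) {
        double s = b;
        for (int j = 0; j < w.length; j++) {
            s += w[j] * x[j];
        }
        return s;
    }

    public static void main(String[] args) {
        double[] origin = {0.0, 0.0};
        BiasSketch withBias = new BiasSketch(new double[]{1.0, -1.0}, 0.5, true);
        BiasSketch noBias = new BiasSketch(new double[]{1.0, -1.0}, 0.5, false);
        System.out.println(withBias.score(origin)); // 0.5
        System.out.println(noBias.score(origin));   // 0.0
    }
}
```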
public void setLambda0(double lambda0)
λ0 controls the L2 regularization penalty.
Parameters:
lambda0 - the L2 regularization penalty to use

public double getLambda0()
Returns the L2 regularization term in use.
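To illustrate how λ0 shrinks the solution, assume squared loss ℓ(w, x, y) = (wx - y)²/2 on a single feature with no bias; both are assumptions made for this example. Setting the derivative of Σ ℓ + (λ0/2) w² to zero gives w = Σ x_i y_i / (Σ x_i² + λ0), so a larger λ0 pulls w toward 0. `Lambda0Shrinkage.fit1d` is a hypothetical helper, not part of LinearBatch.

```java
/**
 * Hypothetical one-feature ridge fit showing the effect of lambda0.
 * Assumes squared loss (w*x - y)^2 / 2 and no bias term, so the minimizer of
 *   sum_i (w*x_i - y_i)^2 / 2 + (lambda0 / 2) * w^2
 * is w = sum_i x_i*y_i / (sum_i x_i^2 + lambda0).
 */
public class Lambda0Shrinkage {

    public static double fit1d(double[] xs, double[] ys, double lambda0) {
        double sxy = 0.0;
        double sxx = 0.0;
        for (int i = 0; i < xs.length; i++) {
            sxy += xs[i] * ys[i];
            sxx += xs[i] * xs[i];
        }
        return sxy / (sxx + lambda0);
    }

    public static void main(String[] args) {
        double[] xs = {1.0, 2.0, 3.0};
        double[] ys = {2.0, 4.0, 6.0}; // exactly y = 2x
        for (double lambda0 : new double[]{0.0, 1.0, 14.0}) {
            System.out.println("lambda0=" + lambda0 + " -> w=" + fit1d(xs, ys, lambda0));
        }
    }
}
```

With data lying exactly on y = 2x, λ0 = 0 recovers w = 2, while λ0 = 14 (equal to Σ x_i² for this data) halves it to w = 1.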
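public void setLoss(LossFunc loss)
Sets the loss function used for the model.
Parameters:
loss - the loss function to use

public LossFunc getLoss()
Returns the loss function in use.

public void setOptimizer(Optimizer2 optimizer)
Sets the method of batch optimization that will be used. null is a valid value, in which case the implementation will attempt to select a reasonable optimizer automatically.
Parameters:
optimizer - the method to use for function minimization

public Optimizer2 getOptimizer()
Returns the optimization method in use, or null.
Returns:
the optimization method in use, or null

public void setTolerance(double tolerance)
Sets the convergence tolerance to use for training.
Parameters:
tolerance - the convergence tolerance

public double getTolerance()
Returns the value of the convergence tolerance parameter.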
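The optimizer and tolerance settings can be pictured as a pluggable minimizer that stops once it judges the problem converged. The sketch below assumes a hypothetical `Optimizer` interface (not JSAT's Optimizer2), uses fixed-step gradient descent in place of whatever method LinearBatch actually selects, and treats convergence as the gradient norm falling below the tolerance. `resolve(null)` mimics the documented behavior of choosing an optimizer automatically when none is set.

```java
import java.util.function.Function;

/**
 * Hypothetical sketch of a pluggable batch optimizer with a convergence tolerance.
 * Assumptions: Optimizer is illustrative (not JSAT's Optimizer2), fixed-step
 * gradient descent stands in for the real method, and "converged" means the
 * gradient norm dropped below the tolerance.
 */
public class ToleranceSketch {

    /** Hypothetical minimal optimizer contract. */
    public interface Optimizer {
        double[] minimize(Function<double[], double[]> grad, double[] w0, double tolerance);
    }

    /** Plain gradient descent with a fixed step size and an iteration cap. */
    public static final class GradientDescent implements Optimizer {
        private final double step;
        private final int maxIters;

        public GradientDescent(double step, int maxIters) {
            this.step = step;
            this.maxIters = maxIters;
        }

        @Override
        public double[] minimize(Function<double[], double[]> grad, double[] w0, double tolerance) {
            double[] w = w0.clone();
            for (int it = 0; it < maxIters; it++) {
                double[] g = grad.apply(w);
                double sq = 0.0;
                for (double gj : g) {
                    sq += gj * gj;
                }
                if (Math.sqrt(sq) < tolerance) {
                    break; // converged: gradient norm is below the tolerance
                }
                for (int j = 0; j < w.length; j++) {
                    w[j] -= step * g[j];
                }
            }
            return w;
        }
    }

    /** Mirrors the documented contract: null means "choose an optimizer automatically". */
    public static Optimizer resolve(Optimizer optimizer) {
        return optimizer != null ? optimizer : new GradientDescent(0.1, 10_000);
    }

    public static void main(String[] args) {
        // Gradient of f(v) = (v0 - 3)^2 / 2 + (v1 + 1)^2 / 2, minimized at (3, -1)
        Function<double[], double[]> grad = v -> new double[]{v[0] - 3.0, v[1] + 1.0};
        double[] w = resolve(null).minimize(grad, new double[]{0.0, 0.0}, 1e-8);
        System.out.println(w[0] + ", " + w[1]);
    }
}
```

A looser tolerance ends training sooner at the cost of a less exact minimizer; for this quadratic, the distance to (3, -1) at termination is bounded by the tolerance itself.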
public CategoricalResults classify(DataPoint data)
Performs classification on the given data point.
Specified by:
classify in interface Classifier
Parameters:
data - the data point to classify

public void trainC(ClassificationDataSet dataSet, Classifier warmSolution)
Trains the classifier and constructs a model for classification using the given data set.
Specified by:
trainC in interface WarmClassifier
Parameters:
dataSet - the data set to train on
warmSolution - the solution to use to warm start this model

public void trainC(ClassificationDataSet D, ExecutorService threadPool)
Trains the classifier and constructs a model for classification using the given data set.
Specified by:
trainC in interface Classifier
Parameters:
D - the data set to train on
threadPool - the source of threads to use.

public void trainC(ClassificationDataSet D, Classifier warmSolution, ExecutorService threadPool)
Trains the classifier and constructs a model for classification using the given data set.
Specified by:
trainC in interface WarmClassifier
Parameters:
D - the data set to train on
warmSolution - the solution to use to warm start this model
threadPool - the source of threads to use.

public void trainC(ClassificationDataSet dataSet)
Trains the classifier and constructs a model for classification using the given data set.
Specified by:
trainC in interface Classifier
Parameters:
dataSet - the data set to train on

public void train(RegressionDataSet D, ExecutorService threadPool)
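The overloads that accept an ExecutorService describe it only as the source of threads to use. One plausible use, assumed here rather than documented, is to split the batch gradient sum across worker tasks. This sketch does that for squared loss plus the L2 gradient term λ0 w; `ParallelGradient` and its chunking scheme are hypothetical.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;

/**
 * Hypothetical sketch: the batch gradient of
 *   sum_i (w.x_i - y_i)^2 / 2 + (lambda0 / 2) * ||w||^2
 * computed in parallel, one task per contiguous chunk of data points.
 * Squared loss and the chunking scheme are assumptions for illustration.
 */
public class ParallelGradient {

    public static double[] gradient(double[][] xs, double[] ys, double[] w,
                                    double lambda0, ExecutorService pool, int chunks)
            throws InterruptedException, ExecutionException {
        int n = xs.length;
        int d = w.length;
        int size = (n + chunks - 1) / chunks; // ceil(n / chunks), chunks >= 1
        List<Future<double[]>> parts = new ArrayList<>();
        for (int start = 0; start < n; start += size) {
            final int lo = start;
            final int hi = Math.min(n, start + size);
            // Each task fills its own partial gradient array, so no locking is needed
            parts.add(pool.submit(() -> {
                double[] g = new double[d];
                for (int i = lo; i < hi; i++) {
                    double r = -ys[i];
                    for (int j = 0; j < d; j++) {
                        r += w[j] * xs[i][j];
                    }
                    // d/dw of (w.x - y)^2 / 2 is (w.x - y) * x
                    for (int j = 0; j < d; j++) {
                        g[j] += r * xs[i][j];
                    }
                }
                return g;
            }));
        }
        double[] total = new double[d];
        for (Future<double[]> f : parts) {
            double[] g = f.get();
            for (int j = 0; j < d; j++) {
                total[j] += g[j];
            }
        }
        // Gradient of the L2 penalty: lambda0 * w
        for (int j = 0; j < d; j++) {
            total[j] += lambda0 * w[j];
        }
        return total;
    }

    public static void main(String[] args) throws Exception {
        ExecutorService pool = Executors.newFixedThreadPool(2);
        try {
            double[][] xs = {{1, 0}, {0, 1}, {1, 1}, {2, 1}};
            double[] ys = {1, 2, 3, 4};
            double[] g = gradient(xs, ys, new double[]{0.0, 0.0}, 0.1, pool, 2);
            System.out.println(g[0] + ", " + g[1]);
        } finally {
            pool.shutdown();
        }
    }
}
```

Because partial results are combined only after `Future.get()`, the threads never share mutable state. Different chunkings change the floating-point summation order, so results may differ in the last few bits.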
public void train(RegressionDataSet dataSet, Regressor warmSolution)
Trains the regressor and constructs a model for regression using the given data set.
Specified by:
train in interface WarmRegressor
Parameters:
dataSet - the data set to train on
warmSolution - the solution to use to warm start this model

public void train(RegressionDataSet D, Regressor warmSolution, ExecutorService threadPool)
Trains the regressor and constructs a model for regression using the given data set.
Specified by:
train in interface WarmRegressor
Parameters:
D - the data set to train on
warmSolution - the solution to use to warm start this model
threadPool - the source of threads to use.

public void train(RegressionDataSet dataSet)
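Warm starting begins optimization from an existing solution instead of from zero. The hypothetical `WarmStartSketch` below runs the same gradient descent twice on a small ridge regression problem with λ0 = 0.9: once from w = 0 and once from a model previously trained with λ0 = 1.0. Because the earlier optimum lies close to the new one, in this example the warm run reaches the tolerance in fewer iterations. Squared loss, the step size, and the stopping rule are illustrative assumptions, not LinearBatch's training procedure.

```java
/**
 * Hypothetical sketch of cold vs. warm starting the same batch optimizer.
 * Objective: sum_i (w.x_i - y_i)^2 / 2 + (lambda0 / 2) * ||w||^2, minimized by
 * fixed-step gradient descent until the gradient norm drops below a tolerance.
 * This is an illustration only, not JSAT's training code.
 */
public class WarmStartSketch {

    // Full-batch gradient of the regularized squared loss
    public static double[] grad(double[][] xs, double[] ys, double[] w, double lambda0) {
        double[] g = new double[w.length];
        for (int i = 0; i < xs.length; i++) {
            double r = -ys[i];
            for (int j = 0; j < w.length; j++) {
                r += w[j] * xs[i][j];
            }
            for (int j = 0; j < w.length; j++) {
                g[j] += r * xs[i][j];
            }
        }
        for (int j = 0; j < w.length; j++) {
            g[j] += lambda0 * w[j];
        }
        return g;
    }

    // Updates w in place and returns the number of gradient steps taken
    public static int train(double[][] xs, double[] ys, double[] w, double lambda0, double tol) {
        final double step = 0.05;
        final int maxIters = 100_000;
        for (int it = 0; it < maxIters; it++) {
            double[] g = grad(xs, ys, w, lambda0);
            double sq = 0.0;
            for (double gj : g) {
                sq += gj * gj;
            }
            if (Math.sqrt(sq) < tol) {
                return it;
            }
            for (int j = 0; j < w.length; j++) {
                w[j] -= step * g[j];
            }
        }
        return maxIters;
    }

    public static void main(String[] args) {
        double[][] xs = {{1, 0}, {0, 1}, {1, 1}, {2, 1}};
        double[] ys = {1, 2, 3, 4};
        double[] previous = new double[2];
        train(xs, ys, previous, 1.0, 1e-10); // an earlier model trained with lambda0 = 1.0

        double[] cold = new double[2];
        int coldIters = train(xs, ys, cold, 0.9, 1e-10);
        double[] warm = previous.clone(); // warm start from the earlier solution
        int warmIters = train(xs, ys, warm, 0.9, 1e-10);
        System.out.println("cold: " + coldIters + " iterations, warm: " + warmIters);
    }
}
```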
public List<Parameter> getParameters()
Returns the list of parameters that can be altered for this learner.
Specified by:
getParameters in interface Parameterized

public Parameter getParameter(String paramName)
Returns the parameter with the given name.
Specified by:
getParameter in interface Parameterized
Parameters:
paramName - the name of the parameter to obtain

public boolean warmFromSameDataOnly()
Some models can only be warm started from a solution trained on the exact same data set as the model being warm started. If this is the case, true will be returned. Training on a different data set when this returns true has undefined behavior: it may cause an error, or it may cause the algorithm to take longer or reach a worse solution. When true is returned, the data set must be left unaltered; this includes mutating the stored values or re-arranging the data points within the data set.
Specified by:
warmFromSameDataOnly in interface WarmClassifier
Specified by:
warmFromSameDataOnly in interface WarmRegressor
Returns:
true if the algorithm can only be warm started from the model trained on the exact same data set.

public Vec getRawWeight(int index)
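Returns the raw weight vector associated with the given class index. In some cases a ConstantVector object may be returned. When the model has only a single weight vector, index = 0 should be used.
Specified by:
getRawWeight in interface SimpleWeightVectorModel
Parameters:
index - the class index to get the weight vector for

public double getBias(int index)
Returns the bias term used with the weight vector for the given class index. If no bias term is in use, 0 will be returned. When the model has only a single weight vector, index = 0 should be used.
Specified by:
getBias in interface SimpleWeightVectorModel
Parameters:
index - the class index to get the bias term for

public int numWeightsVecs()
Returns the number of weight vectors that can be returned.
Specified by:
numWeightsVecs in interface SimpleWeightVectorModel
Returns:
the number of indices for which SimpleWeightVectorModel.getRawWeight(int) can be called.

public boolean supportsWeightedData()
Indicates whether the model knows how to train using weighted data points.
Specified by:
supportsWeightedData in interface Classifier
Specified by:
supportsWeightedData in interface Regressor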
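A caller consuming a model through the SimpleWeightVectorModel methods might score a point against every weight vector and pick the highest score. In this sketch, `double[]` stands in for JSAT's Vec, the nested `WeightVectorModel` interface only mirrors the method names on this page, and the argmax-of-score decision rule is an assumption about how multiple weight vectors are used for multi-class prediction.

```java
/**
 * Hypothetical sketch of consuming a SimpleWeightVectorModel-style model:
 * one weight vector and bias per class index, scored as w_k . x + b_k.
 * double[] replaces JSAT's Vec; the interface mirrors method names only.
 */
public class WeightVectorScoring {

    /** Illustrative stand-in for the SimpleWeightVectorModel methods. */
    public interface WeightVectorModel {
        double[] getRawWeight(int index);
        double getBias(int index);
        int numWeightsVecs();
    }

    // Returns the index k with the largest score w_k . x + b_k
    public static int argmaxClass(WeightVectorModel model, double[] x) {
        int best = 0;
        double bestScore = Double.NEGATIVE_INFINITY;
        for (int k = 0; k < model.numWeightsVecs(); k++) {
            double[] w = model.getRawWeight(k);
            double s = model.getBias(k);
            for (int j = 0; j < x.length; j++) {
                s += w[j] * x[j];
            }
            if (s > bestScore) {
                bestScore = s;
                best = k;
            }
        }
        return best;
    }

    public static void main(String[] args) {
        double[][] ws = {{1.0, 0.0}, {0.0, 1.0}, {-1.0, -1.0}};
        double[] bs = {0.0, 0.0, 0.5};
        WeightVectorModel model = new WeightVectorModel() {
            public double[] getRawWeight(int index) { return ws[index]; }
            public double getBias(int index) { return bs[index]; }
            public int numWeightsVecs() { return ws.length; }
        };
        System.out.println(argmaxClass(model, new double[]{2.0, 1.0}));   // 0
        System.out.println(argmaxClass(model, new double[]{-1.0, -2.0})); // 2
    }
}
```

For a model with a single weight vector (numWeightsVecs() returning 1), only index 0 is valid and the loop reduces to evaluating one score.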
public LinearBatch clone()

public static Distribution guessLambda0(DataSet d)
Guesses the distribution to use for the regularization term λ0.
Parameters:
d - the data set to get the guess for

Copyright © 2017. All rights reserved.