public class RegressionModelEvaluation extends Object
Constructor and Description |
---|
RegressionModelEvaluation(Regressor regressor, RegressionDataSet dataSet) Creates a new RegressionModelEvaluation that will perform serial training. |
RegressionModelEvaluation(Regressor regressor, RegressionDataSet dataSet, ExecutorService threadpool) Creates a new RegressionModelEvaluation that will perform parallel training. |
Modifier and Type | Method and Description |
---|---|
void | addScorer(RegressionScore scorer) Adds a new score object that will be used as part of the evaluation when calling evaluateCrossValidation(int, java.util.Random) or evaluateTestSet(jsat.regression.RegressionDataSet). |
void | evaluateCrossValidation(int folds) Performs an evaluation of the regressor using the training data set. |
void | evaluateCrossValidation(int folds, Random rand) Performs an evaluation of the regressor using the training data set. |
void | evaluateCrossValidation(List<RegressionDataSet> lcds) Performs an evaluation of the regressor using the training data set, where the folds of the training data set are provided by the user. |
void | evaluateCrossValidation(List<RegressionDataSet> lcds, List<RegressionDataSet> trainCombinations) Note: Most people should never need to call this method. |
void | evaluateTestSet(RegressionDataSet testSet) Performs an evaluation of the regressor using the initial data set to train, and testing on the given data set. |
double | getErrorStndDev() Returns the standard deviation of the error from all runs. |
Regressor[] | getKeptModels() Returns the models that were kept after the last evaluation. |
double | getMaxError() Returns the maximum squared error observed from all runs. |
double | getMeanError() Returns the mean squared error from all runs. |
double | getMinError() Returns the minimum squared error from all runs. |
Regressor | getRegressor() Returns the regressor that was to be evaluated. |
OnLineStatistics | getScoreStats(RegressionScore score) Gets the statistics associated with the given score. |
long | getTotalClassificationTime() Returns the total number of milliseconds spent performing regression on the testing set. |
long | getTotalTrainingTime() Returns the total number of milliseconds spent training the regressor. |
boolean | isKeepModels() Indicates whether the models trained during evaluation will be kept. |
void | prettyPrintRegressionScores() Prints out the regression score information in a convenient format. |
void | setDataTransformProcess(DataTransformProcess dtp) Sets the data transform process to use when performing cross validation. |
void | setKeepModels(boolean keepModels) Set this to true in order to keep the trained models after evaluation. |
void | setWarmModels(Regressor... warmModels) Sets the models that will be used for warm starting training. |
public RegressionModelEvaluation(Regressor regressor, RegressionDataSet dataSet, ExecutorService threadpool)

Creates a new RegressionModelEvaluation that will perform parallel training.

Parameters:
regressor - the regressor model to evaluate
dataSet - the data set to train or perform cross validation from
threadpool - the source of threads for training of models
public RegressionModelEvaluation(Regressor regressor, RegressionDataSet dataSet)

Creates a new RegressionModelEvaluation that will perform serial training.

Parameters:
regressor - the regressor model to evaluate
dataSet - the data set to train or perform cross validation from
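The two constructors differ only in whether training runs serially or on a caller-supplied thread pool. A minimal sketch of both, assuming MultipleLinearRegression as the regressor and a placeholder loadData() helper for building the RegressionDataSet:

```java
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;

import jsat.regression.MultipleLinearRegression;
import jsat.regression.RegressionDataSet;
import jsat.regression.RegressionModelEvaluation;

public class EvaluationSetup
{
    public static void main(String[] args)
    {
        RegressionDataSet trainData = loadData(); // assumed helper that builds a RegressionDataSet

        // Serial evaluation: all training happens on the calling thread
        RegressionModelEvaluation serialEval =
                new RegressionModelEvaluation(new MultipleLinearRegression(), trainData);

        // Parallel evaluation: training work is submitted to the given thread pool
        ExecutorService pool = Executors.newFixedThreadPool(Runtime.getRuntime().availableProcessors());
        RegressionModelEvaluation parallelEval =
                new RegressionModelEvaluation(new MultipleLinearRegression(), trainData, pool);

        // ... run evaluations with serialEval or parallelEval ...

        pool.shutdown(); // the caller owns the ExecutorService and must shut it down
    }

    private static RegressionDataSet loadData()
    {
        throw new UnsupportedOperationException("placeholder for data loading");
    }
}
```

The snippets for the remaining methods below continue this sketch, reusing serialEval and trainData.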
public void setKeepModels(boolean keepModels)

Set this to true in order to keep the trained models after evaluation. They can then be retrieved using the getKeptModels() method. The default value is false.

Parameters:
keepModels - true to keep the trained models after evaluation, false to discard them
public boolean isKeepModels()

Indicates whether the models trained during evaluation will be kept; see getKeptModels().

Returns:
true if trained models will be kept after evaluation
public Regressor[] getKeptModels()

Returns the models that were kept after the last evaluation. null will be returned instead if isKeepModels() returns false, which is the default.

Returns:
the models kept from the last evaluation, or null if models are not being kept
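A sketch of keeping and inspecting the per-fold models, continuing the construction example above (assumes jsat.regression.Regressor is also imported):

```java
// Keep the model trained in each fold so it can be inspected afterwards
serialEval.setKeepModels(true);
serialEval.evaluateCrossValidation(10);

if (serialEval.isKeepModels())
{
    Regressor[] foldModels = serialEval.getKeptModels(); // one trained model per evaluation run
    System.out.println("Kept " + foldModels.length + " trained models");
}
```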
public void setWarmModels(Regressor... warmModels)

Sets the models that will be used for warm starting training.

Parameters:
warmModels - the models to use for warm start training
public void setDataTransformProcess(DataTransformProcess dtp)

Sets the data transform process to use when performing cross validation.

Parameters:
dtp - the transformation process to clone for use during evaluation
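A sketch of the two configuration setters, assuming a previously trained Regressor named priorModel (hypothetical) and an empty DataTransformProcess; the exact API for adding transforms to the process depends on the JSAT version in use:

```java
import jsat.datatransform.DataTransformProcess;

// Warm-start each training run from an already-trained model (priorModel is an assumed variable)
serialEval.setWarmModels(priorModel);

// Provide a (possibly empty) transform pipeline; it is cloned for use during evaluation
DataTransformProcess dtp = new DataTransformProcess();
// dtp.addTransform(...); // add transforms here; the method and its arguments vary across JSAT versions
serialEval.setDataTransformProcess(dtp);
```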
public void evaluateCrossValidation(int folds)

Performs an evaluation of the regressor using the training data set.

Parameters:
folds - the number of folds for cross validation

Throws:
UntrainedModelException - if the number of folds given is less than 2
public void evaluateCrossValidation(int folds, Random rand)

Performs an evaluation of the regressor using the training data set.

Parameters:
folds - the number of folds for cross validation
rand - the source of randomness for generating the cross validation sets

Throws:
UntrainedModelException - if the number of folds given is less than 2
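For repeatable cross validation results, the overload taking a Random can be used with a fixed seed. A sketch, continuing the earlier setup:

```java
import java.util.Random;

// 10-fold cross validation with a fixed seed so results are reproducible
serialEval.evaluateCrossValidation(10, new Random(42));

System.out.println("Mean squared error: " + serialEval.getMeanError()
        + " +/- " + serialEval.getErrorStndDev());
```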
public void evaluateCrossValidation(List<RegressionDataSet> lcds)

Performs an evaluation of the regressor using the training data set, where the folds of the training data set are provided by the user.

Parameters:
lcds - the training data set already split into folds
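A sketch of supplying the folds yourself; it assumes RegressionDataSet inherits a cvSet(int) splitting method from DataSet, which should be verified against the JSAT version in use:

```java
import java.util.List;

// Build the folds once (assumed cvSet helper) and hand them to the evaluator,
// e.g. so the exact same split can be reused when comparing several regressors
List<RegressionDataSet> folds = trainData.cvSet(5);
serialEval.evaluateCrossValidation(folds);
```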
public void evaluateCrossValidation(List<RegressionDataSet> lcds, List<RegressionDataSet> trainCombinations)

Note: Most people should never need to call this method. It lets the caller supply the training combination for each fold explicitly (so that results such as DataSet.getNumericColumns() may get re-used and benefit from its caching) instead of having them rebuilt by evaluateCrossValidation(java.util.List).

Parameters:
lcds - training data set already split into folds
trainCombinations - each index contains the training data sans the data stored in the fold associated with that index
public void evaluateTestSet(RegressionDataSet testSet)

Performs an evaluation of the regressor using the initial data set to train, and testing on the given data set.

Parameters:
testSet - the data set to perform testing on
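A sketch of a plain train/test evaluation, assuming a second RegressionDataSet named testData that was held out from training:

```java
// Train once on the data set given to the constructor, then score on the held-out data
serialEval.evaluateTestSet(testData);

System.out.println("Test MSE: " + serialEval.getMeanError());
System.out.println("Training took " + serialEval.getTotalTrainingTime() + " ms");
```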
public void addScorer(RegressionScore scorer)

Adds a new score object that will be used as part of the evaluation when calling evaluateCrossValidation(int, java.util.Random) or evaluateTestSet(jsat.regression.RegressionDataSet). The statistics for the given score are reset on every call, and the mean / standard deviation comes from multiple folds in cross validation. The score statistics can be retrieved from getScoreStats(jsat.regression.evaluation.RegressionScore) after one of the evaluation methods has been called.

Parameters:
scorer - the score method to keep track of
public OnLineStatistics getScoreStats(RegressionScore score)

Gets the statistics associated with the given score. If the score is not being tracked, null will be returned. The object passed in does not need to be the exact same object passed to addScorer(jsat.regression.evaluation.RegressionScore), it only needs to be equal to the object.

Parameters:
score - the score type to get the result statistics

Returns:
the statistics for the given score, or null if the score is not in the evaluation set
public void prettyPrintRegressionScores()

Prints out the regression score information in a convenient format. If no scores have been added via the addScorer(RegressionScore) method, nothing will be printed.
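A sketch tying the scoring pieces together, continuing the earlier snippets and assuming the MeanSquaredError and MeanAbsoluteError score classes from jsat.regression.evaluation:

```java
import jsat.math.OnLineStatistics;
import jsat.regression.evaluation.MeanAbsoluteError;
import jsat.regression.evaluation.MeanSquaredError;

// Register the scores to track before running an evaluation
serialEval.addScorer(new MeanSquaredError());
serialEval.addScorer(new MeanAbsoluteError());

serialEval.evaluateCrossValidation(10, new Random(42));

// Per-score statistics: the mean / standard deviation come from the individual folds.
// The lookup key only needs to be equal to the registered score, not the same object.
OnLineStatistics mseStats = serialEval.getScoreStats(new MeanSquaredError());
System.out.println("MSE: " + mseStats.getMean() + " +/- " + mseStats.getStandardDeviation());

// Or print everything that was tracked in one go
serialEval.prettyPrintRegressionScores();
```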
public double getMinError()

Returns the minimum squared error from all runs.

public double getMaxError()

Returns the maximum squared error observed from all runs.

public double getMeanError()

Returns the mean squared error from all runs.

public double getErrorStndDev()

Returns the standard deviation of the error from all runs.

public long getTotalTrainingTime()

Returns the total number of milliseconds spent training the regressor.

public long getTotalClassificationTime()

Returns the total number of milliseconds spent performing regression on the testing set.

public Regressor getRegressor()

Returns the regressor that was to be evaluated.
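A short sketch of reading the aggregate error and timing statistics after any of the evaluation calls above:

```java
System.out.println("Min / mean / max squared error: "
        + serialEval.getMinError() + " / "
        + serialEval.getMeanError() + " / "
        + serialEval.getMaxError());
System.out.println("Error std-dev: " + serialEval.getErrorStndDev());
System.out.println("Training time (ms): " + serialEval.getTotalTrainingTime());
System.out.println("Prediction time (ms): " + serialEval.getTotalClassificationTime());
```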
Copyright © 2017. All rights reserved.