public class ClassificationModelEvaluation extends Object
Constructor | Description |
---|---|
ClassificationModelEvaluation(Classifier classifier, ClassificationDataSet dataSet) | Constructs a new object that can perform evaluations on the model. |
ClassificationModelEvaluation(Classifier classifier, ClassificationDataSet dataSet, ExecutorService threadpool) | Constructs a new object that can perform evaluations on the model. |

Modifier and Type | Method | Description |
---|---|---|
void | addScorer(ClassificationScore scorer) | Adds a new score object that will be used as part of the evaluation when calling evaluateCrossValidation(int, java.util.Random) or evaluateTestSet(jsat.classifiers.ClassificationDataSet). |
boolean | doseStoreResults() | |
void | evaluateCrossValidation(int folds) | Performs an evaluation of the classifier using the training data set. |
void | evaluateCrossValidation(int folds, Random rand) | Performs an evaluation of the classifier using the training data set. |
void | evaluateCrossValidation(List<ClassificationDataSet> lcds) | Performs an evaluation of the classifier using the training data set, where the folds of the training data set are provided by the user. |
void | evaluateCrossValidation(List<ClassificationDataSet> lcds, List<ClassificationDataSet> trainCombinations) | Note: Most people should never need to call this method. |
void | evaluateTestSet(ClassificationDataSet testSet) | Performs an evaluation of the classifier using the initial data set to train, and testing on the given data set. |
Classifier | getClassifier() | Returns the classifier that was originally given for evaluation. |
double[][] | getConfusionMatrix() | |
double | getCorrectWeights() | Returns the total value of the weights for data points that were classified correctly. |
double | getErrorRate() | Computes the weighted error rate of the classifier. |
OnLineStatistics | getErrorRateStats() | Returns the object that keeps track of the error on individual evaluations. |
Classifier[] | getKeptModels() | Returns the models that were kept after the last evaluation. |
double[] | getPointWeights() | If keepPredictions(boolean) was set, this method will return the array storing the weights for each of the points that were classified. |
CategoricalResults[] | getPredictions() | If keepPredictions(boolean) was set, this method will return the array storing the predictions made by the classifier during evaluation. |
OnLineStatistics | getScoreStats(ClassificationScore score) | Gets the statistics associated with the given score. |
double | getSumOfWeights() | Returns the total value of the weights for all data points that were tested against. |
long | getTotalClassificationTime() | Returns the total number of milliseconds spent performing classification on the testing set. |
long | getTotalTrainingTime() | Returns the total number of milliseconds spent training the classifier. |
int[] | getTruths() | If keepPredictions(boolean) was set, this method will return the array storing the target classes that should have been predicted during evaluation. |
boolean | isKeepModels() | Returns true if the models trained during evaluation will be kept. |
void | keepPredictions(boolean keepPredictions) | Indicates whether or not the predictions made during evaluation should be stored with the expected value for retrieval later. |
void | prettyPrintClassificationScores() | Prints out the classification information in a convenient format. |
void | prettyPrintConfusionMatrix() | Pretty-prints the confusion matrix to System.out, assuming the output is currently at the start of a new line. |
void | setDataTransformProcess(DataTransformProcess dtp) | Sets the data transform process to use when performing cross validation. |
void | setKeepModels(boolean keepModels) | Set this to true in order to keep the trained models after evaluation. |
void | setWarmModels(Classifier... warmModels) | Sets the models that will be used for warm starting training. |
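
The evaluateCrossValidation(int folds, Random rand) method splits the training data into folds using the given source of randomness, trains on all but one fold and tests on the held-out fold. The standalone sketch below shows that general k-fold idea on row indices only. KFoldSketch, splitIndices and trainingIndices are hypothetical names, no JSAT types are used, and the shuffle-then-deal-round-robin rule is an assumption rather than JSAT's actual splitting logic. Where JSAT throws UntrainedModelException for fewer than two folds, the sketch throws IllegalArgumentException.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Random;

/** Hypothetical k-fold splitter over row indices; not part of JSAT. */
final class KFoldSketch {

    /** Shuffles 0..n-1 with rand, then deals the indices round-robin into the folds. */
    static List<List<Integer>> splitIndices(int n, int folds, Random rand) {
        if (folds < 2) // JSAT reports this case with UntrainedModelException instead
            throw new IllegalArgumentException("need at least 2 folds, got " + folds);
        List<Integer> order = new ArrayList<>();
        for (int i = 0; i < n; i++)
            order.add(i);
        Collections.shuffle(order, rand); // same seed gives the same split

        List<List<Integer>> result = new ArrayList<>();
        for (int f = 0; f < folds; f++)
            result.add(new ArrayList<>());
        for (int i = 0; i < n; i++)
            result.get(i % folds).add(order.get(i));
        return result;
    }

    /** Training rows for fold `test`: every index that is not in that fold. */
    static List<Integer> trainingIndices(List<List<Integer>> folds, int test) {
        List<Integer> train = new ArrayList<>();
        for (int f = 0; f < folds.size(); f++)
            if (f != test)
                train.addAll(folds.get(f));
        return train;
    }

    public static void main(String[] args) {
        List<List<Integer>> folds = splitIndices(10, 3, new Random(42));
        for (int k = 0; k < folds.size(); k++)
            System.out.println("fold " + k + ": test=" + folds.get(k)
                    + " train=" + trainingIndices(folds, k));
    }
}
```

Here trainingIndices(folds, k) mirrors the idea behind the trainCombinations argument of evaluateCrossValidation(List, List): entry k holds everything except fold k.
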
public ClassificationModelEvaluation(Classifier classifier, ClassificationDataSet dataSet)
Constructs a new object that can perform evaluations on the model.
Parameters:
classifier - the model to train and evaluate
dataSet - the training data set.

public ClassificationModelEvaluation(Classifier classifier, ClassificationDataSet dataSet, ExecutorService threadpool)
Constructs a new object that can perform evaluations on the model.
Parameters:
classifier - the model to train and evaluate
dataSet - the training data set.
threadpool - the source of threads for parallel training. If set to null, training will be done using the Classifier.trainC(jsat.classifiers.ClassificationDataSet) method.

public void setKeepModels(boolean keepModels)
Set this to true in order to keep the trained models after evaluation. They can then be retrieved using the getKeptModels() method. The default value is false.
Parameters:
keepModels - true to keep the trained models after evaluation, false to discard them.

public boolean isKeepModels()
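Indicates whether the models trained during evaluation will be kept. The models can be obtained with getKeptModels().
Returns:
true if trained models will be kept after evaluation.

public Classifier[] getKeptModels()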
Returns the models that were kept after the last evaluation. null will be returned instead if isKeepModels() returns false, which is the default.
Returns:
the kept models, or null if models are not being kept.

public void setWarmModels(Classifier... warmModels)
Sets the models that will be used for warm starting training.
Parameters:
warmModels - the models to use for warm start training

public void setDataTransformProcess(DataTransformProcess dtp)
Sets the data transform process to use when performing cross validation.
Parameters:
dtp - the transformation process to clone for use during evaluation

public void evaluateCrossValidation(int folds)
Performs an evaluation of the classifier using the training data set.
Parameters:
folds - the number of folds for cross validation
Throws:
UntrainedModelException - if the number of folds given is less than 2

public void evaluateCrossValidation(int folds, Random rand)
Performs an evaluation of the classifier using the training data set.
Parameters:
folds - the number of folds for cross validation
rand - the source of randomness for generating the cross validation sets
Throws:
UntrainedModelException - if the number of folds given is less than 2

public void evaluateCrossValidation(List<ClassificationDataSet> lcds)
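Performs an evaluation of the classifier using the training data set, where the folds of the training data set are provided by the user.
Parameters:
lcds - the training data set already split into folds

public void evaluateCrossValidation(List<ClassificationDataSet> lcds, List<ClassificationDataSet> trainCombinations)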
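Note: Most people should never need to call this method. Here both the folds and the matching training combinations are provided by the user, which allows expensive objects to be re-used (for example, DataSet.getNumericColumns() may get re-used and benefit from its caching).
See also: evaluateCrossValidation(java.util.List).
Parameters:
lcds - the training data set already split into folds
trainCombinations - each index contains the training data sans the data stored in the fold associated with that index

public void evaluateTestSet(ClassificationDataSet testSet)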
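Performs an evaluation of the classifier using the initial data set to train, and testing on the given data set.
Parameters:
testSet - the data set to perform testing on

public void addScorer(ClassificationScore scorer)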
Adds a new score object that will be used as part of the evaluation when calling evaluateCrossValidation(int, java.util.Random) or evaluateTestSet(jsat.classifiers.ClassificationDataSet). The statistics for the given score are reset on every call, and the mean / standard deviation comes from multiple folds in cross validation. The results can be obtained with getScoreStats(ClassificationScore) after one of the evaluation methods has been called.
Parameters:
scorer - the score method to keep track of.

public OnLineStatistics getScoreStats(ClassificationScore score)
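Gets the statistics associated with the given score. If the score is not in the evaluation set, null will be returned. The object passed in does not need to be the exact same object passed to addScorer(ClassificationScore); it only needs to be equal to that object.
Parameters:
score - the score type to get the result statistics for
Returns:
the result statistics for the given score, or null if the score is not in the evaluation set

public void keepPredictions(boolean keepPredictions)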
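Indicates whether or not the predictions made during evaluation should be stored with the expected value for retrieval later.
Parameters:
keepPredictions - true if space should be allocated to store the predictions made

public boolean doseStoreResults()
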
public CategoricalResults[] getPredictions()
If keepPredictions(boolean) was set, this method will return the array storing the predictions made by the classifier during evaluation. These results may not be in the same order as the data set they came from, but the order is paired with getTruths().

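
Each element of getPredictions() is a CategoricalResults, which holds a probability for every class rather than a single label (CategoricalResults provides its own mostLikely() method for this). Assuming those probabilities are copied into a plain double[] indexed by class, a hard prediction is the index of the largest probability. MostLikelySketch is a hypothetical, JSAT-free stand-in so the example runs on its own; the lowest-index tie rule is its own choice, not documented JSAT behaviour.

```java
/** Hypothetical helper: index of the most likely class; ties go to the lowest index. */
final class MostLikelySketch {

    static int mostLikely(double[] classProbs) {
        if (classProbs.length == 0)
            throw new IllegalArgumentException("no classes");
        int best = 0;
        for (int c = 1; c < classProbs.length; c++)
            if (classProbs[c] > classProbs[best]) // strict '>' keeps the first maximum on ties
                best = c;
        return best;
    }

    public static void main(String[] args) {
        double[][] predictions = { {0.7, 0.2, 0.1}, {0.1, 0.3, 0.6} };
        for (double[] p : predictions)
            System.out.println(mostLikely(p)); // prints 0, then 2
    }
}
```
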
public int[] getTruths()
If keepPredictions(boolean) was set, this method will return the array storing the target classes that should have been predicted during evaluation. These results may not be in the same order as the data set they came from, but the order is paired with getPredictions().

public double[] getPointWeights()
If keepPredictions(boolean) was set, this method will return the array storing the weights for each of the points that were classified.

public double[][] getConfusionMatrix()

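
getConfusionMatrix() carries no description on this page. A common weighted construction, built from arrays paired index by index the way getTruths(), getPredictions() and getPointWeights() are, is sketched below. It assumes rows index the true class and columns the predicted class, and that each cell accumulates point weights rather than counts; both are assumptions about JSAT's layout, and ConfusionSketch is a hypothetical name.

```java
import java.util.Arrays;

/** Hypothetical weighted confusion matrix; assumed orientation is [true class][predicted class]. */
final class ConfusionSketch {

    static double[][] confusion(int[] truths, int[] predicted, double[] weights, int numClasses) {
        if (truths.length != predicted.length || truths.length != weights.length)
            throw new IllegalArgumentException("arrays must be paired index by index");
        double[][] m = new double[numClasses][numClasses];
        for (int i = 0; i < truths.length; i++)
            m[truths[i]][predicted[i]] += weights[i]; // add this point's weight to its cell
        return m;
    }

    public static void main(String[] args) {
        int[] truths    = {0, 0, 1, 1, 2};
        int[] predicted = {0, 1, 1, 1, 0};
        double[] w      = {1.0, 1.0, 2.0, 1.0, 0.5};
        for (double[] row : confusion(truths, predicted, w, 3))
            System.out.println(Arrays.toString(row));
    }
}
```

With this sample data the diagonal sums to 4.0 out of a total weight of 5.5, the same relationship that getCorrectWeights() and getSumOfWeights() summarize.
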
public void prettyPrintConfusionMatrix()
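Assuming that the output is currently at the start of a new line, pretty-prints the confusion matrix to System.out.
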
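
The exact text produced by prettyPrintConfusionMatrix() is not shown on this page. As a rough illustration of a fixed-width layout, the hypothetical MatrixPrinter below renders a square matrix with one header row of predicted classes and one row per true class; the "p"/"t" labels and 8-character columns are invented for the example. It returns a String instead of printing, which keeps it testable, and uses Locale.ROOT so the decimal separator does not depend on the machine's locale.

```java
import java.util.Locale;

/** Hypothetical fixed-width rendering of a square confusion matrix; not JSAT's exact format. */
final class MatrixPrinter {

    static String format(double[][] m) {
        StringBuilder sb = new StringBuilder();
        sb.append(String.format(Locale.ROOT, "%8s", ""));              // blank corner cell
        for (int c = 0; c < m.length; c++)
            sb.append(String.format(Locale.ROOT, "%8s", "p" + c));     // predicted-class headers
        sb.append('\n');
        for (int r = 0; r < m.length; r++) {
            sb.append(String.format(Locale.ROOT, "%8s", "t" + r));     // true-class label
            for (double v : m[r])
                sb.append(String.format(Locale.ROOT, "%8.2f", v));
            sb.append('\n');
        }
        return sb.toString();
    }

    public static void main(String[] args) {
        double[][] m = { {1.0, 1.0}, {0.0, 3.0} };
        System.out.print(format(m)); // starts on a fresh line, as the method above assumes
    }
}
```
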
public void prettyPrintClassificationScores()
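Prints out the classification information in a convenient format. If no additional scores were added via the addScorer(ClassificationScore) method, nothing will be printed.

public double getCorrectWeights()
Returns the total value of the weights for data points that were classified correctly.
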
public double getSumOfWeights()
Returns the total value of the weights for all data points that were tested against.

public double getErrorRate()
Computes the weighted error rate of the classifier.

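
getErrorRate() reports a weighted error rate. Assuming it is derived from the two weight totals as 1 - getCorrectWeights() / getSumOfWeights(), and that getErrorRateStats() accumulates one error value per evaluation (for example, per cross-validation fold), both can be sketched as below. RunningStats uses Welford's online algorithm as a stand-in for OnLineStatistics; whether JSAT uses exactly this update, and whether its standard deviation divides by n or n - 1, is not stated on this page, so treat both as assumptions. ErrorStatsSketch and its members are hypothetical names.

```java
/** Hypothetical sketch: weighted error rate plus running statistics over folds. */
final class ErrorStatsSketch {

    /** Assumed definition: fraction of the total weight that was misclassified. */
    static double errorRate(double correctWeights, double sumOfWeights) {
        if (sumOfWeights <= 0)
            throw new IllegalArgumentException("no weight has been evaluated yet");
        return 1.0 - correctWeights / sumOfWeights;
    }

    /** Minimal stand-in for an online statistics accumulator (Welford's algorithm). */
    static final class RunningStats {
        private long n;
        private double mean;
        private double m2; // sum of squared deviations from the running mean

        void add(double x) {
            n++;
            double delta = x - mean;
            mean += delta / n;
            m2 += delta * (x - mean);
        }

        double getMean() { return mean; }

        /** Sample standard deviation (n - 1 denominator); 0 when fewer than two values. */
        double getStandardDeviation() {
            return n < 2 ? 0.0 : Math.sqrt(m2 / (n - 1));
        }
    }

    public static void main(String[] args) {
        // Per-fold {correct weight, total weight} pairs, e.g. from 3-fold cross validation.
        double[][] folds = { {9, 10}, {8, 10}, {7, 10} };
        RunningStats stats = new RunningStats();
        double correct = 0, total = 0;
        for (double[] f : folds) {
            stats.add(errorRate(f[0], f[1])); // one error value per evaluation
            correct += f[0];
            total += f[1];
        }
        System.out.println("overall error = " + errorRate(correct, total));
        System.out.println("mean fold error = " + stats.getMean());
        System.out.println("std of fold errors = " + stats.getStandardDeviation());
    }
}
```

When every fold carries the same total weight, as here, the overall error and the mean fold error coincide; with unequal fold weights they can differ, which is one reason to keep both numbers.
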
public OnLineStatistics getErrorRateStats()
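Returns the object that keeps track of the error on individual evaluations: one value per fold when cross validation was used, or a single value when evaluateTestSet(jsat.classifiers.ClassificationDataSet) was called.

public long getTotalTrainingTime()
Returns the total number of milliseconds spent training the classifier.
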
public long getTotalClassificationTime()
Returns the total number of milliseconds spent performing classification on the testing set.

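
getTotalTrainingTime() and getTotalClassificationTime() report milliseconds summed over the evaluation. A minimal sketch of how such totals can be accumulated around the train and classify phases follows; it times with System.nanoTime() (a monotonic clock) and converts to milliseconds at the end. PhaseTimer, timeTraining and timeClassification are hypothetical names, and which clock JSAT actually uses is not stated on this page.

```java
import java.util.concurrent.TimeUnit;
import java.util.function.Supplier;

/** Hypothetical accumulator for training / classification wall-clock time, reported in ms. */
final class PhaseTimer {
    private long trainingNanos;
    private long classificationNanos;

    /** Runs `work`, adding its elapsed time to the training total even if it throws. */
    <T> T timeTraining(Supplier<T> work) {
        long start = System.nanoTime();
        try {
            return work.get();
        } finally {
            trainingNanos += System.nanoTime() - start;
        }
    }

    /** Runs `work`, adding its elapsed time to the classification total even if it throws. */
    <T> T timeClassification(Supplier<T> work) {
        long start = System.nanoTime();
        try {
            return work.get();
        } finally {
            classificationNanos += System.nanoTime() - start;
        }
    }

    long getTotalTrainingTime() { return TimeUnit.NANOSECONDS.toMillis(trainingNanos); }

    long getTotalClassificationTime() { return TimeUnit.NANOSECONDS.toMillis(classificationNanos); }

    private static void pause(long millis) {
        try {
            Thread.sleep(millis);
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
        }
    }

    public static void main(String[] args) {
        PhaseTimer timer = new PhaseTimer();
        // Stand-ins for "train on a fold" and "classify the held-out fold".
        Integer model = timer.timeTraining(() -> { pause(20); return 42; });
        Integer label = timer.timeClassification(() -> { pause(5); return model % 3; });
        System.out.println("predicted label: " + label);
        System.out.println("training ms: " + timer.getTotalTrainingTime());
        System.out.println("classification ms: " + timer.getTotalClassificationTime());
    }
}
```
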
public Classifier getClassifier()
Returns the classifier that was originally given for evaluation.

Copyright © 2017. All rights reserved.