public class SBS extends RemoveAttributeTransform
Fields inherited from class RemoveAttributeTransform: catIndexMap, numIndexMap
Constructor and Description |
---|
SBS(int minFeatures, int maxFeatures, ClassificationDataSet cds, Classifier evaluater, int folds, double maxDecrease): Performs SBS feature selection for a classification problem |
SBS(int minFeatures, int maxFeatures, Classifier evaluater, double maxDecrease): Performs SBS feature selection for a classification problem |
SBS(int minFeatures, int maxFeatures, RegressionDataSet rds, Regressor evaluater, int folds, double maxDecrease): Performs SBS feature selection for a regression problem |
SBS(int minFeatures, int maxFeatures, Regressor evaluater, double maxDecrease): Performs SBS feature selection for a regression problem |
Modifier and Type | Method and Description |
---|---|
SBS | clone() |
void | fit(DataSet data): Fits this transform to the given dataset. |
int | getFolds() |
double | getMaxDecrease(): Returns the maximum allowable decrease in accuracy from one set of features to the next |
int | getMaxFeatures(): Returns the maximum number of features to find |
int | getMinFeatures(): Returns the minimum number of features to find |
Set<Integer> | getSelectedCategorical(): Returns a copy of the set of categorical features selected by the search algorithm |
Set<Integer> | getSelectedNumerical(): Returns a copy of the set of numerical features selected by the search algorithm. |
protected static int | SBSRemoveFeature(Set<Integer> available, DataSet dataSet, Set<Integer> catToRemove, Set<Integer> numToRemove, Set<Integer> catSelecteed, Set<Integer> numSelected, Object evaluater, int folds, Random rand, int maxFeatures, double[] PbestScore, double maxDecrease): Attempts to remove one feature from the list while maintaining its accuracy |
void | setFolds(int folds): Sets the number of folds to use for cross validation when estimating the error rate |
void | setMaxDecrease(double maxDecrease): Sets the maximum allowable decrease in accuracy (increase in error) from the previous set of features to the new current set. |
void | setMaxFeatures(int maxFeatures): Sets the maximum number of features that may be selected |
void | setMinFeatures(int minFeatures): Sets the minimum number of features that must be selected |
Methods inherited from class RemoveAttributeTransform: consolidate, getKeptNominal, getKeptNumeric, getReverseNominalMap, getReverseNumericMap, setUp, transform
public SBS(int minFeatures, int maxFeatures, Classifier evaluater, double maxDecrease)
minFeatures
- the minimum number of features to find
maxFeatures
- the maximum number of features to find
evaluater
- the classifier to use in determining accuracy given a feature subset
maxDecrease
- the maximum tolerable decrease in accuracy when a feature is removed
public SBS(int minFeatures, int maxFeatures, ClassificationDataSet cds, Classifier evaluater, int folds, double maxDecrease)
minFeatures
- the minimum number of features to find
maxFeatures
- the maximum number of features to find
cds
- the data set to perform feature selection on
evaluater
- the classifier to use in determining accuracy given a feature subset
folds
- the number of cross validation folds to use in selection
maxDecrease
- the maximum tolerable decrease in accuracy when a feature is removed
public SBS(int minFeatures, int maxFeatures, Regressor evaluater, double maxDecrease)
minFeatures
- the minimum number of features to find
maxFeatures
- the maximum number of features to find
evaluater
- the regressor to use in determining accuracy given a feature subset
maxDecrease
- the maximum tolerable increase in the error rate when a feature is removed
public SBS(int minFeatures, int maxFeatures, RegressionDataSet rds, Regressor evaluater, int folds, double maxDecrease)
minFeatures
- the minimum number of features to find
maxFeatures
- the maximum number of features to find
rds
- the data set to perform feature selection on
evaluater
- the regressor to use in determining accuracy given a feature subset
folds
- the number of cross validation folds to use in selection
maxDecrease
- the maximum tolerable increase in the error rate when a feature is removed
public void fit(DataSet data)
Description copied from interface: DataTransform
Fits this transform to the given dataset. A FailedToFitException may be thrown.
Specified by: fit in interface DataTransform
Overrides: fit in class RemoveAttributeTransform
data
- the dataset to fit this transform to
public SBS clone()
Specified by: clone in interface DataTransform
Overrides: clone in class RemoveAttributeTransform
public Set<Integer> getSelectedCategorical()
public Set<Integer> getSelectedNumerical()
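The summary above states that both getters return a copy of the selected set. A minimal sketch of that defensive-copy pattern follows; the class `SelectedFeatures` and its field are hypothetical stand-ins for illustration and are not taken from the library's source.

```java
import java.util.HashSet;
import java.util.Set;

// Hypothetical holder class; it only illustrates the "returns a copy" contract.
final class SelectedFeatures {
    private final Set<Integer> selectedNumerical = new HashSet<>();

    SelectedFeatures(Set<Integer> initial) {
        selectedNumerical.addAll(initial);
    }

    /** Returns a fresh copy, so callers cannot alter the internal selection. */
    Set<Integer> getSelectedNumerical() {
        return new HashSet<>(selectedNumerical);
    }
}
```

With this pattern, a caller may add to or clear the returned set freely; a later call to the getter still reports the original selection.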
protected static int SBSRemoveFeature(Set<Integer> available, DataSet dataSet, Set<Integer> catToRemove, Set<Integer> numToRemove, Set<Integer> catSelecteed, Set<Integer> numSelected, Object evaluater, int folds, Random rand, int maxFeatures, double[] PbestScore, double maxDecrease)
available
- the set of available features from [0, n) to consider for removal
dataSet
- the original data set to perform feature selection from
catToRemove
- the current set of categorical features to remove
numToRemove
- the current set of numerical features to remove
catSelecteed
- the current set of categorical features we are keeping
numSelected
- the current set of numerical features we are keeping
evaluater
- the classifier or regressor to perform evaluations with
folds
- the number of cross validation folds to determine performance
rand
- the source of randomness
maxFeatures
- the maximum allowable number of features
PbestScore
- an array to behave as a pointer to the best score seen so far
maxDecrease
- the maximum allowable decrease in accuracy from the best accuracy we see
public void setMaxDecrease(double maxDecrease)
maxDecrease
- the maximum allowable decrease in the accuracy from removing a feature
public double getMaxDecrease()
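To make the roles of `PbestScore` and `maxDecrease` in `SBSRemoveFeature` more concrete, here is a generic sketch of sequential backward selection. It is not the library's implementation: the class `SbsSketch`, the methods `removeOne` and `select`, and the toy `error` function are all hypothetical. The acceptance rule (accept a removal when the error rises by at most `maxDecrease` over the best error seen, or when more than `maxFeatures` features remain) is an assumption based on the parameter descriptions above.

```java
import java.util.Set;
import java.util.TreeSet;
import java.util.function.ToDoubleFunction;

// Generic illustration of sequential backward selection; not the library's code.
final class SbsSketch {

    /**
     * Tries to remove one feature from {@code kept}. Returns the removed
     * feature index, or -1 if no removal was acceptable. bestScore[0] holds
     * the lowest error seen so far and is updated in place, which is how a
     * one-element array can act like a pointer in Java.
     */
    static int removeOne(Set<Integer> kept, ToDoubleFunction<Set<Integer>> error,
                         int maxFeatures, double[] bestScore, double maxDecrease) {
        int bestFeature = -1;
        double bestCandidate = Double.POSITIVE_INFINITY;
        for (int f : new TreeSet<>(kept)) {   // iterate over a snapshot
            kept.remove(f);                   // tentatively drop feature f
            double e = error.applyAsDouble(kept);
            kept.add(f);                      // restore it
            if (e < bestCandidate) {
                bestCandidate = e;
                bestFeature = f;
            }
        }
        if (bestFeature < 0) {
            return -1;
        }
        // Assumed rule: keep removing while over maxFeatures, otherwise only
        // if the error does not rise by more than maxDecrease.
        boolean tooMany = kept.size() > maxFeatures;
        boolean tolerable = bestCandidate - bestScore[0] <= maxDecrease;
        if (!tooMany && !tolerable) {
            return -1;
        }
        kept.remove(bestFeature);
        bestScore[0] = Math.min(bestScore[0], bestCandidate);
        return bestFeature;
    }

    /** Starts from all features and removes them one at a time. */
    static Set<Integer> select(int nFeatures, int minFeatures, int maxFeatures,
                               ToDoubleFunction<Set<Integer>> error, double maxDecrease) {
        Set<Integer> kept = new TreeSet<>();
        for (int i = 0; i < nFeatures; i++) {
            kept.add(i);
        }
        double[] bestScore = { error.applyAsDouble(kept) };
        while (kept.size() > minFeatures) {
            if (removeOne(kept, error, maxFeatures, bestScore, maxDecrease) < 0) {
                break;
            }
        }
        return kept;
    }

    public static void main(String[] args) {
        // Toy error: features 0 and 2 are informative, every feature adds a small cost.
        ToDoubleFunction<Set<Integer>> error = s ->
                (s.contains(0) ? 0.0 : 0.3) + (s.contains(2) ? 0.0 : 0.2) + 0.01 * s.size();
        System.out.println(select(5, 1, 5, error, 0.0)); // prints [0, 2]
    }
}
```

In a real setting the `error` function would be a cross-validated error estimate from the classifier or regressor, which is why each removal step is expensive: it evaluates one model per remaining feature.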
public void setMinFeatures(int minFeatures)
minFeatures
- the minimum number of features to learn
public int getMinFeatures()
public void setMaxFeatures(int maxFeatures)
maxFeatures
- the maximum number of features to find
public int getMaxFeatures()
public void setFolds(int folds)
folds
- the number of folds to use for cross validation when estimating the error rate
public int getFolds()
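For readers unfamiliar with how `folds` affects the error estimate, the following is a minimal sketch of k-fold cross validation over row indices. It is a generic illustration and does not reproduce the library's internals; `KFoldSketch`, `split`, `crossValidatedError`, and the `trainAndTest` callback are hypothetical names.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Random;
import java.util.function.ToDoubleBiFunction;

// Generic k-fold cross validation over row indices; not the library's code.
final class KFoldSketch {

    /** Shuffles indices 0..n-1 and deals them into {@code folds} nearly equal groups. */
    static List<List<Integer>> split(int n, int folds, Random rand) {
        List<Integer> idx = new ArrayList<>();
        for (int i = 0; i < n; i++) {
            idx.add(i);
        }
        Collections.shuffle(idx, rand);
        List<List<Integer>> parts = new ArrayList<>();
        for (int f = 0; f < folds; f++) {
            parts.add(new ArrayList<>());
        }
        for (int i = 0; i < n; i++) {
            parts.get(i % folds).add(idx.get(i));
        }
        return parts;
    }

    /**
     * Holds each fold out once for testing, trains on the rest, and returns
     * the mean of the per-fold errors reported by {@code trainAndTest}.
     */
    static double crossValidatedError(int n, int folds, Random rand,
            ToDoubleBiFunction<List<Integer>, List<Integer>> trainAndTest) {
        List<List<Integer>> parts = split(n, folds, rand);
        double total = 0.0;
        for (int f = 0; f < folds; f++) {
            List<Integer> train = new ArrayList<>();
            for (int g = 0; g < folds; g++) {
                if (g != f) {
                    train.addAll(parts.get(g));
                }
            }
            total += trainAndTest.applyAsDouble(train, parts.get(f));
        }
        return total / folds;
    }
}
```

More folds mean more training data per model and a less pessimistic estimate, at the cost of training more models for every candidate feature subset.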
Copyright © 2017. All rights reserved.