public class DecisionTree extends Object implements Classifier, Regressor, Parameterized, TreeLearner
Creates a decision tree from DecisionStumps. How this tree performs is controlled by the selected pruning method and by the methods used in the stump.

Modifier and Type | Class and Description |
---|---|
protected static class | DecisionTree.Node |

Modifier | Constructor | Description |
---|---|---|
public | `DecisionTree()` | Creates a Decision Tree that uses TreePruner.PruningMethod.REDUCED_ERROR pruning on a held-out 10% of the data. |
protected | `DecisionTree(DecisionTree toCopy)` | Copy constructor. |
public | `DecisionTree(int maxDepth)` | Creates a Decision Tree that does not do any pruning and is built out only to the specified depth. |
public | `DecisionTree(int maxDepth, int minSamples, TreePruner.PruningMethod pruningMethod, double testProportion)` | Creates a new decision tree classifier. |
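
The no-argument constructor holds out 10% of the training data (the `testProportion`) for pruning. A minimal sketch of such a hold-out split in plain Java follows; `HoldOutSplit` and `split` are hypothetical names, and JSAT's own sampling (shuffling, rounding) may differ. The range check mirrors the documented requirement that the proportion lie in [0, 1].

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Random;

/** Hypothetical helper: splits point indices into a growing set and a pruning set. */
final class HoldOutSplit {
    final List<Integer> grow = new ArrayList<>();
    final List<Integer> prune = new ArrayList<>();

    /** Shuffles 0..n-1 and puts round(n * testProportion) indices aside for pruning. */
    static HoldOutSplit split(int n, double testProportion, long seed) {
        if (Double.isNaN(testProportion) || testProportion < 0 || testProportion > 1)
            throw new IllegalArgumentException("testProportion must be in [0, 1]");
        List<Integer> idx = new ArrayList<>();
        for (int i = 0; i < n; i++) idx.add(i);
        Collections.shuffle(idx, new Random(seed));
        int pruneCount = (int) Math.round(n * testProportion);
        HoldOutSplit s = new HoldOutSplit();
        s.prune.addAll(idx.subList(0, pruneCount)); // held out, never used to grow the tree
        s.grow.addAll(idx.subList(pruneCount, n));  // used to build the unpruned tree
        return s;
    }

    public static void main(String[] args) {
        HoldOutSplit s = split(200, 0.10, 42L);
        System.out.println(s.grow.size() + " to grow, " + s.prune.size() + " to prune"); // 180 to grow, 20 to prune
    }
}
```

Shuffling before cutting keeps the pruning set from depending on the order in which the data set was loaded.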

Modifier and Type | Method | Description |
---|---|---|
`CategoricalResults` | `classify(DataPoint data)` | Performs classification on the given data point. |
`DecisionTree` | `clone()` | |
`static DecisionTree` | `getC45Tree()` | Returns a Decision Tree with settings initialized so that its behavior is approximately that of the C4.5 decision tree algorithm when used on classification data. |
`ImpurityScore.ImpurityMeasure` | `getGainMethod()` | |
`int` | `getMaxDepth()` | The maximum depth that this classifier may build trees to. |
`int` | `getMinResultSplitSize()` | Returns the minimum result split size that may be considered for use as the attribute to split on. |
`int` | `getMinSamples()` | The minimum number of samples needed at each step in order to continue branching. |
`Parameter` | `getParameter(String paramName)` | Returns the parameter with the given name. |
`List<Parameter>` | `getParameters()` | Returns the list of parameters that can be altered for this learner. |
`TreePruner.PruningMethod` | `getPruningMethod()` | Returns the method of pruning used after tree construction. |
`double` | `getTestProportion()` | Returns the proportion of the training set that is put aside to perform pruning with. |
`TreeNodeVisitor` | `getTreeNodeVisitor()` | Obtains a node visitor for the tree learner that can be used to traverse and predict from the learned tree. |
`protected DecisionTree.Node` | `makeNodeC(List<DataPointPair<Integer>> dataPoints, Set<Integer> options, int depth, ExecutorService threadPool, ModifiableCountDownLatch mcdl)` | Makes a new node for classification. |
`protected DecisionTree.Node` | `makeNodeR(List<DataPointPair<Double>> dataPoints, Set<Integer> options, int depth, ExecutorService threadPool, ModifiableCountDownLatch mcdl)` | Makes a new node for regression. |
`double` | `regress(DataPoint data)` | |
`void` | `setGainMethod(ImpurityScore.ImpurityMeasure gainMethod)` | |
`void` | `setMaxDepth(int maxDepth)` | Sets the maximum depth that this classifier may build trees to. |
`void` | `setMinResultSplitSize(int size)` | When a split is made, it may be that outliers cause the split to segregate a minority of points from the majority. |
`void` | `setMinSamples(int minSamples)` | Sets the minimum number of samples needed at each step in order to continue branching. |
`void` | `setPruningMethod(TreePruner.PruningMethod pruningMethod)` | Sets the method of pruning that will be used after tree construction. |
`void` | `setTestProportion(double testProportion)` | Sets the proportion of the training set that is put aside to perform pruning with. |
`boolean` | `supportsWeightedData()` | Indicates whether the model knows how to train using weighted data points. |
`void` | `train(RegressionDataSet dataSet)` | |
`void` | `train(RegressionDataSet dataSet, ExecutorService threadPool)` | |
`void` | `train(RegressionDataSet dataSet, Set<Integer> options)` | |
`void` | `train(RegressionDataSet dataSet, Set<Integer> options, ExecutorService threadPool)` | |
`void` | `trainC(ClassificationDataSet dataSet)` | Trains the classifier and constructs a model for classification using the given data set. |
`void` | `trainC(ClassificationDataSet dataSet, ExecutorService threadPool)` | Trains the classifier and constructs a model for classification using the given data set. |
`void` | `trainC(ClassificationDataSet dataSet, Set<Integer> options)` | |
`protected void` | `trainC(ClassificationDataSet dataSet, Set<Integer> options, ExecutorService threadPool)` | Performs exactly the same as `trainC(jsat.classifiers.ClassificationDataSet, java.util.concurrent.ExecutorService)`, but the user can specify a subset of the features to be considered. |
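
`makeNodeC` and `makeNodeR` receive an `ExecutorService` and a `ModifiableCountDownLatch`, which suggests that subtrees are built as separate pool tasks while the latch tracks how many are still pending. JSAT's `ModifiableCountDownLatch` is not documented on this page, so the sketch below uses the standard `java.util.concurrent.Phaser` as an assumed stand-in: a party is registered before each subtree task is submitted and deregistered when the task finishes, and the caller waits for all of them. The node type and the fixed-depth rule are hypothetical.

```java
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Phaser;

/** Hypothetical sketch: subtrees are expanded as pool tasks tracked by a Phaser. */
final class ParallelTreeBuild {
    static final class Node {
        final int depth;
        Node left, right; // filled in by whichever task expands this node

        Node(int depth) { this.depth = depth; }
    }

    /** Builds a complete binary tree of the given depth, submitting one task per child. */
    static Node build(int maxDepth, ExecutorService pool) {
        Phaser pending = new Phaser(1); // one party for the calling thread
        Node root = new Node(0);
        expand(root, maxDepth, pool, pending);
        pending.arriveAndAwaitAdvance(); // returns once every registered task has arrived
        return root;
    }

    private static void expand(Node node, int maxDepth, ExecutorService pool, Phaser pending) {
        if (node.depth >= maxDepth) return; // stopping rule: this node stays a leaf
        node.left = new Node(node.depth + 1);
        node.right = new Node(node.depth + 1);
        for (Node child : new Node[] {node.left, node.right}) {
            pending.register(); // count up before submitting, so the wait cannot end early
            pool.submit(() -> {
                try {
                    expand(child, maxDepth, pool, pending);
                } finally {
                    pending.arriveAndDeregister(); // count down once this subtree is done
                }
            });
        }
    }

    static int count(Node n) {
        return n == null ? 0 : 1 + count(n.left) + count(n.right);
    }

    public static void main(String[] args) {
        ExecutorService pool = Executors.newFixedThreadPool(4);
        try {
            System.out.println(count(build(4, pool))); // 31 = 2^(4+1) - 1 nodes
        } finally {
            pool.shutdown();
        }
    }
}
```

Registering before submitting is the key ordering: a parent task always registers its children before it deregisters itself, so the pending count can only reach zero when the whole tree is finished. A `Phaser` supports at most 65535 registered parties, so very large trees would need tiered phasers or a different counter.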
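
`public DecisionTree()`
Creates a Decision Tree that uses TreePruner.PruningMethod.REDUCED_ERROR pruning on a held-out 10% of the data.

`public DecisionTree(int maxDepth)`
Creates a Decision Tree that does not do any pruning and is built out only to the specified depth.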
Parameters:
maxDepth - the maximum depth of the tree to create

`public DecisionTree(int maxDepth, int minSamples, TreePruner.PruningMethod pruningMethod, double testProportion)`
Creates a new decision tree classifier.
Parameters:
maxDepth - the maximum depth of the tree to create
minSamples - the minimum number of samples needed to continue branching
pruningMethod - the method of pruning to use after construction
testProportion - the proportion of the data set to put aside to use for pruning

`protected DecisionTree(DecisionTree toCopy)`
Copy constructor.
Parameters:
toCopy - the object to copy

`public void train(RegressionDataSet dataSet, ExecutorService threadPool)`

`public void train(RegressionDataSet dataSet, Set<Integer> options)`

`public void train(RegressionDataSet dataSet, Set<Integer> options, ExecutorService threadPool)`

`public void train(RegressionDataSet dataSet)`
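
`public static DecisionTree getC45Tree()`
Returns a Decision Tree with settings initialized so that its behavior is approximately that of the C4.5 decision tree algorithm when used on classification data.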
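
This page only says that `getC45Tree()` approximates C4.5; it does not list the settings. For background, C4.5 ranks candidate splits by information gain ratio: the information gain of a split divided by its split information (the entropy of the branch sizes). The sketch below computes that criterion from class counts; it illustrates C4.5 in general, not JSAT's code, and `GainRatio` is a hypothetical name.

```java
/** Hypothetical illustration of the C4.5 gain-ratio criterion; not JSAT code. */
final class GainRatio {
    /** Entropy in bits of a vector of non-negative counts. */
    static double entropy(double[] counts) {
        double total = 0;
        for (double c : counts) total += c;
        double h = 0;
        for (double c : counts) {
            if (c > 0) {
                double p = c / total;
                h -= p * Math.log(p) / Math.log(2);
            }
        }
        return h;
    }

    /** branches[b][k] = number of points of class k routed to branch b. */
    static double gainRatio(double[][] branches) {
        int numClasses = branches[0].length;
        double[] parent = new double[numClasses];
        double[] branchSizes = new double[branches.length];
        double total = 0;
        for (int b = 0; b < branches.length; b++) {
            for (int k = 0; k < numClasses; k++) {
                parent[k] += branches[b][k];
                branchSizes[b] += branches[b][k];
            }
            total += branchSizes[b];
        }
        double childEntropy = 0; // size-weighted entropy after the split
        for (int b = 0; b < branches.length; b++)
            childEntropy += branchSizes[b] / total * entropy(branches[b]);
        double gain = entropy(parent) - childEntropy;
        double splitInfo = entropy(branchSizes); // penalizes splits into many small branches
        return splitInfo == 0 ? 0 : gain / splitInfo;
    }

    public static void main(String[] args) {
        // rows are branches, columns are class counts
        System.out.println(gainRatio(new double[][] {{4, 0}, {0, 4}})); // 1.0: perfect split
        System.out.println(gainRatio(new double[][] {{2, 2}, {2, 2}})); // 0.0: uninformative split
    }
}
```

Dividing by the split information is what distinguishes gain ratio from plain information gain, which tends to favor attributes with many distinct values.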

`public void setGainMethod(ImpurityScore.ImpurityMeasure gainMethod)`

`public ImpurityScore.ImpurityMeasure getGainMethod()`
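
`setGainMethod` selects the `ImpurityScore.ImpurityMeasure` used to score candidate splits. The available measures are not listed on this page, so the sketch below assumes Gini impurity as one common example and scores a split by the size-weighted impurity of its branches (lower is better). It is a plain-Java illustration with hypothetical names, not JSAT's `ImpurityScore`.

```java
/** Hypothetical illustration of scoring a split with Gini impurity; not JSAT's ImpurityScore. */
final class GiniSplit {
    /** Gini impurity 1 - sum_k p_k^2 of a vector of class counts. */
    static double gini(double[] counts) {
        double total = 0;
        for (double c : counts) total += c;
        if (total == 0) return 0;
        double sumSq = 0;
        for (double c : counts) {
            double p = c / total;
            sumSq += p * p;
        }
        return 1 - sumSq;
    }

    /** Size-weighted Gini impurity of the branches of a split; lower is better. */
    static double weightedGini(double[][] branches) {
        double total = 0;
        double[] sizes = new double[branches.length];
        for (int b = 0; b < branches.length; b++) {
            for (double c : branches[b]) sizes[b] += c;
            total += sizes[b];
        }
        double score = 0;
        for (int b = 0; b < branches.length; b++)
            score += sizes[b] / total * gini(branches[b]);
        return score;
    }

    public static void main(String[] args) {
        double pure = weightedGini(new double[][] {{5, 0}, {0, 5}});  // each branch holds one class
        double mixed = weightedGini(new double[][] {{3, 2}, {2, 3}}); // both branches stay mixed
        System.out.println(pure + " vs " + mixed);
    }
}
```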
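
`public void setMinResultSplitSize(int size)`
When a split is made, it may be that outliers cause the split to segregate a minority of points from the majority.
Parameters:
size - the minimum result split size to use

`public int getMinResultSplitSize()`
Returns the minimum result split size that may be considered for use as the attribute to split on.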
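
`setMinResultSplitSize` guards against splits in which a few outliers are cut off from the majority. One way to express that rule, assuming it means every branch of a candidate split must receive at least `size` points (an interpretation of the description above, not a quote of JSAT's code), is the admissibility check below; the names are hypothetical.

```java
/** Hypothetical admissibility check for a candidate split; assumes a per-branch minimum. */
final class MinResultSplit {
    /** True if every branch receives at least minResultSplitSize of the training points. */
    static boolean isAdmissible(int[] branchSizes, int minResultSplitSize) {
        for (int size : branchSizes)
            if (size < minResultSplitSize) return false; // a tiny branch suggests an outlier split
        return true;
    }

    public static void main(String[] args) {
        // 98 points vs. 2 outliers: rejected when each branch must hold at least 5 points
        System.out.println(isAdmissible(new int[] {98, 2}, 5));  // false
        System.out.println(isAdmissible(new int[] {60, 40}, 5)); // true
    }
}
```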
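
`public void setMaxDepth(int maxDepth)`
Sets the maximum depth that this classifier may build trees to.
Parameters:
maxDepth - the maximum depth of the trained tree

`public int getMaxDepth()`
The maximum depth that this classifier may build trees to.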
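
`public void setMinSamples(int minSamples)`
Sets the minimum number of samples needed at each step in order to continue branching.
Parameters:
minSamples - the minimum number of samples needed to branch

`public int getMinSamples()`
The minimum number of samples needed at each step in order to continue branching.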
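
`maxDepth` and `minSamples` together decide when a node stops branching. The sketch below assumes a node may branch only while its depth is below `maxDepth` and it holds at least `minSamples` points; the exact boundary comparisons JSAT uses are not stated on this page, so treat them as assumptions. `balancedTreeDepth` is a hypothetical helper that assumes every split halves the data, just to show which limit binds first.

```java
/** Hypothetical stopping rule combining maxDepth and minSamples; boundaries are assumed. */
final class StoppingRule {
    final int maxDepth;
    final int minSamples;

    StoppingRule(int maxDepth, int minSamples) {
        this.maxDepth = maxDepth;
        this.minSamples = minSamples;
    }

    /** True if a node at this depth holding this many points may be split further. */
    boolean mayBranch(int depth, int numSamples) {
        return depth < maxDepth && numSamples >= minSamples;
    }

    /** Leaf depth of an idealized tree in which every split halves the points. */
    int balancedTreeDepth(int numSamples) {
        int depth = 0;
        while (mayBranch(depth, numSamples)) {
            numSamples /= 2; // each child receives half of its parent's points
            depth++;
        }
        return depth;
    }

    public static void main(String[] args) {
        System.out.println(new StoppingRule(10, 8).balancedTreeDepth(1000)); // 7: minSamples binds first
        System.out.println(new StoppingRule(3, 8).balancedTreeDepth(1000));  // 3: maxDepth binds first
    }
}
```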
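
`public void setPruningMethod(TreePruner.PruningMethod pruningMethod)`
Sets the method of pruning that will be used after tree construction.
Parameters:
pruningMethod - the method of pruning that will be used after tree construction
See Also:
TreePruner.PruningMethod

`public TreePruner.PruningMethod getPruningMethod()`
Returns the method of pruning used after tree construction.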
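
The default constructor uses `TreePruner.PruningMethod.REDUCED_ERROR`. Reduced-error pruning is commonly described as a bottom-up pass: each internal node is replaced by a leaf predicting its majority training class whenever that does not increase the error on the held-out pruning set. The sketch below applies that rule to a hypothetical one-feature threshold tree; it is not JSAT's `TreePruner`, and it assumes each node's majority class was recorded during training.

```java
import java.util.ArrayList;
import java.util.List;

/** Hypothetical sketch of reduced-error pruning on a one-feature threshold tree. */
final class ReducedErrorPruning {
    static final class Node {
        final int majority;     // majority training class at this node
        final double threshold; // internal nodes send x < threshold to the left child
        Node left, right;       // both null for a leaf

        Node(int majority, double threshold, Node left, Node right) {
            this.majority = majority;
            this.threshold = threshold;
            this.left = left;
            this.right = right;
        }

        static Node leaf(int majority) { return new Node(majority, Double.NaN, null, null); }

        boolean isLeaf() { return left == null; }

        int predict(double x) {
            if (isLeaf()) return majority;
            return x < threshold ? left.predict(x) : right.predict(x);
        }
    }

    /** A held-out point: one feature value and its true class. */
    static final class Point {
        final double x;
        final int label;

        Point(double x, int label) { this.x = x; this.label = label; }
    }

    static int errors(Node node, List<Point> points) {
        int e = 0;
        for (Point p : points) if (node.predict(p.x) != p.label) e++;
        return e;
    }

    /** Prunes bottom-up: a node collapses if a leaf does no worse on the pruning set. */
    static void prune(Node node, List<Point> pruneSet) {
        if (node.isLeaf()) return;
        List<Point> leftPts = new ArrayList<>();
        List<Point> rightPts = new ArrayList<>();
        for (Point p : pruneSet) {
            if (p.x < node.threshold) leftPts.add(p); else rightPts.add(p);
        }
        prune(node.left, leftPts); // children first, so this decision sees pruned subtrees
        prune(node.right, rightPts);
        int leafErrors = 0;
        for (Point p : pruneSet) if (p.label != node.majority) leafErrors++;
        if (leafErrors <= errors(node, pruneSet)) {
            node.left = null; // the node now predicts its majority class
            node.right = null;
        }
    }

    public static void main(String[] args) {
        Node overfit = new Node(1, 7.0, Node.leaf(1), Node.leaf(0)); // spurious split at x = 7
        Node root = new Node(0, 5.0, Node.leaf(0), overfit);
        List<Point> pruneSet = new ArrayList<>();
        double[][] heldOut = {{2, 0}, {3, 0}, {6, 1}, {8, 1}, {9, 1}};
        for (double[] d : heldOut) pruneSet.add(new Point(d[0], (int) d[1]));
        prune(root, pruneSet);
        System.out.println(overfit.isLeaf() + " " + errors(root, pruneSet)); // true 0
    }
}
```

The spurious split below `x = 7` is removed because a single leaf makes no errors on the held-out points, while the useful split at `x = 5` survives because collapsing it would raise the error from 0 to 3.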
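
`public double getTestProportion()`
Returns the proportion of the training set that is put aside to perform pruning with.

`public void setTestProportion(double testProportion)`
Sets the proportion of the training set that is put aside to perform pruning with.
Parameters:
testProportion - the proportion, must be in the range [0, 1]

`public CategoricalResults classify(DataPoint data)`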
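Description copied from interface: Classifier
Performs classification on the given data point.
Specified by:
classify in interface Classifier
Parameters:
data - the data point to classify

`public void trainC(ClassificationDataSet dataSet, ExecutorService threadPool)`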
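Description copied from interface: Classifier
Trains the classifier and constructs a model for classification using the given data set.
Specified by:
trainC in interface Classifier
Parameters:
dataSet - the data set to train on
threadPool - the source of threads to use

`protected void trainC(ClassificationDataSet dataSet, Set<Integer> options, ExecutorService threadPool)`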
Performs exactly the same as `trainC(jsat.classifiers.ClassificationDataSet, java.util.concurrent.ExecutorService)`, but the user can specify a subset of the features to be considered.
Parameters:
dataSet - the data set to train from
options - the subset of features to split on
threadPool - the source of threads for training

`protected DecisionTree.Node makeNodeC(List<DataPointPair<Integer>> dataPoints, Set<Integer> options, int depth, ExecutorService threadPool, ModifiableCountDownLatch mcdl)`
Makes a new node for classification.
Parameters:
dataPoints - the list of data points paired with their class
options - the attributes that this tree may select from
depth - the current depth of the tree
threadPool - the source of threads
mcdl - count down latch

`protected DecisionTree.Node makeNodeR(List<DataPointPair<Double>> dataPoints, Set<Integer> options, int depth, ExecutorService threadPool, ModifiableCountDownLatch mcdl)`
Makes a new node for regression.
Parameters:
dataPoints - the list of data points paired with their associated real value
options - the attributes that this tree may select from
depth - the current depth of the tree
threadPool - the source of threads
mcdl - count down latch

`public void trainC(ClassificationDataSet dataSet)`
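Description copied from interface: Classifier
Trains the classifier and constructs a model for classification using the given data set.
Specified by:
trainC in interface Classifier
Parameters:
dataSet - the data set to train on

`public void trainC(ClassificationDataSet dataSet, Set<Integer> options)`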
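
The `options` argument restricts which features may be split on. The sketch below shows that idea on plain arrays: only the feature indices in `options` are scanned when searching for the best threshold split. Scoring by misclassification count, binary 0/1 labels, and the array layout are simplifying assumptions, and all names are hypothetical rather than JSAT's.

```java
import java.util.Collections;
import java.util.Set;

/** Hypothetical best-split search that only scans the feature indices in options. */
final class SubsetSplitSearch {
    /** Chosen feature, threshold, and training errors of the split "x[feature] < threshold". */
    static final class Split {
        final int feature;
        final double threshold;
        final int errors;

        Split(int feature, double threshold, int errors) {
            this.feature = feature;
            this.threshold = threshold;
            this.errors = errors;
        }
    }

    /** Errors when each side of the split predicts its majority label (labels are 0 or 1). */
    static int splitErrors(double[][] x, int[] y, int f, double t) {
        int[] left = new int[2];
        int[] right = new int[2];
        for (int i = 0; i < x.length; i++) {
            if (x[i][f] < t) left[y[i]]++;
            else right[y[i]]++;
        }
        return Math.min(left[0], left[1]) + Math.min(right[0], right[1]);
    }

    /** Tries each observed value of each allowed feature as a threshold; other features are never read. */
    static Split bestSplit(double[][] x, int[] y, Set<Integer> options) {
        Split best = null;
        for (int f : options) {
            for (double[] row : x) {
                int e = splitErrors(x, y, f, row[f]);
                if (best == null || e < best.errors) best = new Split(f, row[f], e);
            }
        }
        return best;
    }

    public static void main(String[] args) {
        double[][] x = {{1, 5}, {2, 1}, {8, 4}, {9, 2}};
        int[] y = {0, 0, 1, 1};
        // feature 0 separates the classes perfectly, but only feature 1 is allowed here
        Split s = bestSplit(x, y, Collections.singleton(1));
        System.out.println("feature " + s.feature + ", errors " + s.errors); // feature 1, errors 1
    }
}
```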
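
`public boolean supportsWeightedData()`
Description copied from interface: Classifier
Indicates whether the model knows how to train using weighted data points.
Specified by:
supportsWeightedData in interface Classifier
Specified by:
supportsWeightedData in interface Regressor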
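
`public DecisionTree clone()`

`public TreeNodeVisitor getTreeNodeVisitor()`
Description copied from interface: TreeLearner
Obtains a node visitor for the tree learner that can be used to traverse and predict from the learned tree.
Specified by:
getTreeNodeVisitor in interface TreeLearner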
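
`getTreeNodeVisitor()` returns a visitor that can traverse the learned tree and predict from it. The methods of JSAT's `TreeNodeVisitor` are not listed on this page, so the sketch below defines a hypothetical minimal interface (`isLeaf`, `getPath`, `getChild`, `localPrediction`) and shows the usual prediction loop: follow the child chosen for the data point until a leaf is reached, then use that leaf's prediction.

```java
/** Hypothetical minimal visitor interface; not JSAT's TreeNodeVisitor API. */
interface NodeVisitor {
    boolean isLeaf();
    int getPath(double[] x);     // index of the child this point should follow
    NodeVisitor getChild(int i);
    int localPrediction();       // class predicted when stopping at this node
}

final class VisitorDemo {
    /** Follows the path chosen for x from the root down to a leaf and returns its prediction. */
    static int predict(NodeVisitor root, double[] x) {
        NodeVisitor node = root;
        while (!node.isLeaf())
            node = node.getChild(node.getPath(x));
        return node.localPrediction();
    }

    static NodeVisitor leaf(int cls) {
        return new NodeVisitor() {
            public boolean isLeaf() { return true; }
            public int getPath(double[] x) { return -1; } // leaves have no children to choose
            public NodeVisitor getChild(int i) { throw new IllegalStateException("leaf has no children"); }
            public int localPrediction() { return cls; }
        };
    }

    /** One threshold test on a single feature with two leaf children. */
    static NodeVisitor stump(int feature, double threshold, int leftClass, int rightClass) {
        NodeVisitor left = leaf(leftClass);
        NodeVisitor right = leaf(rightClass);
        return new NodeVisitor() {
            public boolean isLeaf() { return false; }
            public int getPath(double[] x) { return x[feature] < threshold ? 0 : 1; }
            public NodeVisitor getChild(int i) { return i == 0 ? left : right; }
            public int localPrediction() { return leftClass; } // placeholder; predict() never asks internal nodes
        };
    }

    public static void main(String[] args) {
        NodeVisitor root = stump(0, 2.5, 0, 1);
        System.out.println(predict(root, new double[] {1.0}) + " " + predict(root, new double[] {4.0})); // 0 1
    }
}
```

Separating traversal (the visitor) from the learner lets the same loop serve classification and regression trees, as long as each node can report its own local prediction.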
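
`public List<Parameter> getParameters()`
Description copied from interface: Parameterized
Returns the list of parameters that can be altered for this learner.
Specified by:
getParameters in interface Parameterized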
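
`public Parameter getParameter(String paramName)`
Description copied from interface: Parameterized
Returns the parameter with the given name.
Specified by:
getParameter in interface Parameterized
Parameters:
paramName - the name of the parameter to obtain

Copyright © 2017. All rights reserved.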