Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
author_email="",
license="MIT",
package_dir={'': 'src'},
py_modules=["linked_list", "stack", "dbl_linked_list", "queue_ds", "deque", "binheap", "graph", "weighted_graph", "bst"],
py_modules=["linked_list", "stack", "dbl_linked_list", "queue_ds", "deque", "binheap", "graph", "weighted_graph", "bst", "pandas"],
install_requires=[],
extras_require={"test": ["pytest", "pytest-watch", "pytest-cov", "tox"]},
entry_points={}
Expand Down
146 changes: 146 additions & 0 deletions src/decision_tree.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,146 @@
import pandas as pd


class TreeNode(object):
    """A single node of a decision tree classifier.

    Each node keeps the data it was built from, the split that was chosen
    for that data (column, threshold value and gini score), optional left
    and right child nodes, and an optional class label.
    """

    def __init__(self, data, split_value, split_gini, split_col, left=None, right=None, label=None):
        """Store the node's data, split description, children and label."""
        # Payload and the split selected for it.
        self.data = data
        self.split_col = split_col
        self.split_value = split_value
        self.split_gini = split_gini
        # Tree links; both default to None, i.e. a leaf.
        self.left, self.right = left, right
        self.label = label

    def _has_children(self):
        """Return True when either child is set (truthy), otherwise False."""
        return bool(self.right or self.left)


class DecisionTree(object):
    """Binary decision-tree classifier built from a pandas DataFrame.

    The last column of every DataFrame handed to the tree is treated as the
    class label and all preceding columns as candidate split features.  Rows
    whose feature value is strictly less than a node's split value belong to
    the left child; every other row belongs to the right child.
    """

    def __init__(self, max_depth, min_leaf_size):
        """Store the stopping criteria; the tree stays empty until ``fit``.

        max_depth -- a node whose depth count has reached this value is
            never split (the root has depth count 0).
        min_leaf_size -- a node holding this many rows or fewer is never
            split.
        """
        self.max_depth = max_depth
        self.min_leaf_size = min_leaf_size
        self.root = None

    def fit(self, data):
        """Build the tree for ``data`` and store it on ``self.root``.

        The root goes through the same construction as every other node, so
        it receives a label and obeys the stopping criteria too.
        """
        self.root = self._build_tree(data, depth_count=0)

    def _build_tree(self, data, depth_count=1):
        """Return the node for ``data``, recursively attaching children.

        Every node records the best split found for its data and the most
        frequent class label.  Children are attached only when
        ``_can_split`` allows it and the best split sends at least one row
        to each side.
        """
        split_col, split_value, split_gini, split_groups = self._get_split(data)
        node = TreeNode(data, split_value, split_gini, split_col,
                        label=self._majority_label(data))
        if not self._can_split(self._node_gini(data), depth_count, len(data)):
            return node
        # A split with an empty side separates nothing; recursing on it
        # would only produce an unlabeled empty child and a copy of this
        # node.
        if split_groups is None or any(len(group) == 0 for group in split_groups):
            return node
        left_data, right_data = split_groups
        node.left = self._build_tree(left_data, depth_count + 1)
        node.right = self._build_tree(right_data, depth_count + 1)
        return node

    def _can_split(self, gini, depth_count, data_size):
        """Return True when a node may be split further.

        Splitting is refused when ``gini`` is exactly 0.0, when
        ``depth_count`` has reached ``max_depth``, or when ``data_size`` is
        not larger than ``min_leaf_size``.
        """
        if gini == 0.0:
            return False
        if depth_count >= self.max_depth:
            return False
        if data_size <= self.min_leaf_size:
            return False
        return True

    def _depth(self, start=''):
        """Return the number of node levels below and including ``start``.

        With the default ``start`` the depth of the whole tree is returned;
        ``None`` (an empty subtree) has depth 0.
        """
        def depth_wrapped(node):
            if node is None:
                return 0
            return max(depth_wrapped(node.right), depth_wrapped(node.left)) + 1

        # '' is the "argument not given" marker.  It is compared by value
        # because identity comparison against a literal is unreliable.
        if isinstance(start, str) and start == '':
            return depth_wrapped(self.root)
        return depth_wrapped(start)

    def _calculate_gini(self, groups, class_values):
        """Return the gini score of a candidate split.

        For every class and every non-empty group the term ``p * (1 - p)``
        is added, where ``p`` is the share of the group's rows whose last
        column equals the class.  The terms are not weighted by group size.
        A score of 0.0 means every non-empty group holds a single class.
        """
        # Pull out the label column and size once per group; empty groups
        # contribute nothing.
        sized = [(group.iloc[:, -1], float(len(group)))
                 for group in groups if len(group) > 0]
        gini = 0.0
        for class_value in class_values:
            for labels, size in sized:
                proportion = int((labels == class_value).sum()) / size
                gini += proportion * (1.0 - proportion)
        return gini

    def _get_split(self, data):
        """Return ``(column, value, gini, groups)`` for the best split.

        Every distinct value of every feature column is tried as a
        threshold and the first candidate with the lowest gini score wins.
        When there is nothing to try (no rows or no columns) the result is
        ``(inf, inf, inf, None)``.
        """
        best = (float('inf'), float('inf'), float('inf'), None)
        if len(data.columns) == 0:
            return best
        classes = data.iloc[:, -1].unique()
        for col in data.columns.values[:-1]:
            # unique() keeps first-appearance order, so ties resolve as if
            # every row had been tried in turn.
            for value in data[col].unique():
                groups = self._test_split(col, value, data)
                gini = self._calculate_gini(groups, classes)
                if gini < best[2]:
                    best = (col, value, gini, groups)
        return best

    def _test_split(self, col, value, data):
        """Return ``(left, right)``: rows with ``col < value`` and the rest.

        Both parts keep the original index, column order and dtypes.  Rows
        for which the comparison is false, including missing values, land
        in ``right``.
        """
        goes_left = self._column(data, col) < value
        return data[goes_left], data[~goes_left]

    @staticmethod
    def _column(data, col):
        """Return column ``col``; an int that is not a label is a position."""
        if isinstance(col, int) and col not in data.columns:
            return data.iloc[:, col]
        return data[col]

    @staticmethod
    def _majority_label(data):
        """Return the most frequent value of the last column, or None."""
        if len(data.columns) == 0:
            return None
        modes = data.iloc[:, -1].mode()
        return modes.iloc[0] if len(modes) else None

    def _node_gini(self, data):
        """Return the gini score of ``data`` taken as one unsplit group."""
        if len(data.columns) == 0:
            return 0.0
        return self._calculate_gini([data], data.iloc[:, -1].unique())

    def predict(self, data):
        """Return the predicted label(s) for ``data``.

        ``data`` is either one sample that can be indexed by feature column
        (a Series or dict), giving a single label, or a DataFrame, giving a
        list with one label per row.

        Raises ValueError when the tree has not been fitted.
        """
        if self.root is None:
            raise ValueError("DecisionTree.predict called before fit")
        if isinstance(data, pd.DataFrame):
            return [self._predict_one(row) for _, row in data.iterrows()]
        return self._predict_one(data)

    def _predict_one(self, sample):
        """Walk from the root to a leaf for one sample; return its label."""
        node = self.root
        while node._has_children():
            # Same rule as _test_split: strictly smaller goes left.
            if sample[node.split_col] < node.split_value:
                child = node.left
            else:
                child = node.right
            if child is None:
                # The matching branch is missing; answer with this node.
                break
            node = child
        return node.label

101 changes: 101 additions & 0 deletions src/flowers_data.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
petal length (cm),petal width (cm),sepal length (cm),sepal width (cm),target,class_names
1.4,0.2,5.1,3.5,0,setosa
1.4,0.2,4.9,3.0,0,setosa
1.3,0.2,4.7,3.2,0,setosa
1.5,0.2,4.6,3.1,0,setosa
1.4,0.2,5.0,3.6,0,setosa
1.7,0.4,5.4,3.9,0,setosa
1.4,0.3,4.6,3.4,0,setosa
1.5,0.2,5.0,3.4,0,setosa
1.4,0.2,4.4,2.9,0,setosa
1.5,0.1,4.9,3.1,0,setosa
1.5,0.2,5.4,3.7,0,setosa
1.6,0.2,4.8,3.4,0,setosa
1.4,0.1,4.8,3.0,0,setosa
1.1,0.1,4.3,3.0,0,setosa
1.2,0.2,5.8,4.0,0,setosa
1.5,0.4,5.7,4.4,0,setosa
1.3,0.4,5.4,3.9,0,setosa
1.4,0.3,5.1,3.5,0,setosa
1.7,0.3,5.7,3.8,0,setosa
1.5,0.3,5.1,3.8,0,setosa
1.7,0.2,5.4,3.4,0,setosa
1.5,0.4,5.1,3.7,0,setosa
1.0,0.2,4.6,3.6,0,setosa
1.7,0.5,5.1,3.3,0,setosa
1.9,0.2,4.8,3.4,0,setosa
1.6,0.2,5.0,3.0,0,setosa
1.6,0.4,5.0,3.4,0,setosa
1.5,0.2,5.2,3.5,0,setosa
1.4,0.2,5.2,3.4,0,setosa
1.6,0.2,4.7,3.2,0,setosa
1.6,0.2,4.8,3.1,0,setosa
1.5,0.4,5.4,3.4,0,setosa
1.5,0.1,5.2,4.1,0,setosa
1.4,0.2,5.5,4.2,0,setosa
1.5,0.1,4.9,3.1,0,setosa
1.2,0.2,5.0,3.2,0,setosa
1.3,0.2,5.5,3.5,0,setosa
1.5,0.1,4.9,3.1,0,setosa
1.3,0.2,4.4,3.0,0,setosa
1.5,0.2,5.1,3.4,0,setosa
1.3,0.3,5.0,3.5,0,setosa
1.3,0.3,4.5,2.3,0,setosa
1.3,0.2,4.4,3.2,0,setosa
1.6,0.6,5.0,3.5,0,setosa
1.9,0.4,5.1,3.8,0,setosa
1.4,0.3,4.8,3.0,0,setosa
1.6,0.2,5.1,3.8,0,setosa
1.4,0.2,4.6,3.2,0,setosa
1.5,0.2,5.3,3.7,0,setosa
1.4,0.2,5.0,3.3,0,setosa
4.7,1.4,7.0,3.2,1,versicolor
4.5,1.5,6.4,3.2,1,versicolor
4.9,1.5,6.9,3.1,1,versicolor
4.0,1.3,5.5,2.3,1,versicolor
4.6,1.5,6.5,2.8,1,versicolor
4.5,1.3,5.7,2.8,1,versicolor
4.7,1.6,6.3,3.3,1,versicolor
3.3,1.0,4.9,2.4,1,versicolor
4.6,1.3,6.6,2.9,1,versicolor
3.9,1.4,5.2,2.7,1,versicolor
3.5,1.0,5.0,2.0,1,versicolor
4.2,1.5,5.9,3.0,1,versicolor
4.0,1.0,6.0,2.2,1,versicolor
4.7,1.4,6.1,2.9,1,versicolor
3.6,1.3,5.6,2.9,1,versicolor
4.4,1.4,6.7,3.1,1,versicolor
4.5,1.5,5.6,3.0,1,versicolor
4.1,1.0,5.8,2.7,1,versicolor
4.5,1.5,6.2,2.2,1,versicolor
3.9,1.1,5.6,2.5,1,versicolor
4.8,1.8,5.9,3.2,1,versicolor
4.0,1.3,6.1,2.8,1,versicolor
4.9,1.5,6.3,2.5,1,versicolor
4.7,1.2,6.1,2.8,1,versicolor
4.3,1.3,6.4,2.9,1,versicolor
4.4,1.4,6.6,3.0,1,versicolor
4.8,1.4,6.8,2.8,1,versicolor
5.0,1.7,6.7,3.0,1,versicolor
4.5,1.5,6.0,2.9,1,versicolor
3.5,1.0,5.7,2.6,1,versicolor
3.8,1.1,5.5,2.4,1,versicolor
3.7,1.0,5.5,2.4,1,versicolor
3.9,1.2,5.8,2.7,1,versicolor
5.1,1.6,6.0,2.7,1,versicolor
4.5,1.5,5.4,3.0,1,versicolor
4.5,1.6,6.0,3.4,1,versicolor
4.7,1.5,6.7,3.1,1,versicolor
4.4,1.3,6.3,2.3,1,versicolor
4.1,1.3,5.6,3.0,1,versicolor
4.0,1.3,5.5,2.5,1,versicolor
4.4,1.2,5.5,2.6,1,versicolor
4.6,1.4,6.1,3.0,1,versicolor
4.0,1.2,5.8,2.6,1,versicolor
3.3,1.0,5.0,2.3,1,versicolor
4.2,1.3,5.6,2.7,1,versicolor
4.2,1.2,5.7,3.0,1,versicolor
4.2,1.3,5.7,2.9,1,versicolor
4.3,1.3,6.2,2.9,1,versicolor
3.0,1.1,5.1,2.5,1,versicolor
4.1,1.3,5.7,2.8,1,versicolor
11 changes: 11 additions & 0 deletions src/test_dataset2.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
x,y,class_names
2.771244718,1.784783929,0
1.728571309,1.169761413,0
3.678319846,2.81281357,0
3.961043357,2.61995032,0
2.999208922,2.209014212,0
7.497545867,3.162953546,1
9.00220326,3.339047188,1
7.444542326,0.476683375,1
10.12493903,3.234550982,1
6.642287351,3.319983761,1
103 changes: 103 additions & 0 deletions src/test_decision_tree.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
import os

import pandas as pd
import pytest


# The fixture sits next to this test module; resolving it from __file__
# keeps the tests independent of the directory pytest is started from.
DATASET2 = pd.read_csv(
    os.path.join(os.path.dirname(os.path.abspath(__file__)), "test_dataset2.csv")
)

# (threshold, column) pairs: every x value, then every y value of the
# fixture, truncated to three decimals.  Entry i pairs with DATASET2_GINI[i].
DATASET2_VALUES = [
    (2.771, 'x'),
    (1.728, 'x'),
    (3.678, 'x'),
    (3.961, 'x'),
    (2.999, 'x'),
    (7.497, 'x'),
    (9.002, 'x'),
    (7.444, 'x'),
    (10.124, 'x'),
    (6.642, 'x'),
    (1.784, 'y'),
    (1.169, 'y'),
    (2.812, 'y'),
    (2.619, 'y'),
    (2.209, 'y'),
    (3.162, 'y'),
    (3.339, 'y'),
    (0.476, 'y'),
    (3.234, 'y'),
    (3.319, 'y'),
]

# Expected gini score, rounded to three decimals, for splitting DATASET2 at
# the threshold in the same position of DATASET2_VALUES.
DATASET2_GINI = [
    0.494,
    0.500,
    0.408,
    0.278,
    0.469,
    0.408,
    0.469,
    0.278,
    0.494,
    0.000,
    1.000,
    0.494,
    0.640,
    0.819,
    0.934,
    0.278,
    0.494,
    0.500,
    0.408,
    0.469,
]

# X1 < 2.771 Gini=0.494
# X1 < 1.729 Gini=0.500
# X1 < 3.678 Gini=0.408
# X1 < 3.961 Gini=0.278
# X1 < 2.999 Gini=0.469
# X1 < 7.498 Gini=0.408
# X1 < 9.002 Gini=0.469
# X1 < 7.445 Gini=0.278
# X1 < 10.125 Gini=0.494
# X1 < 6.642 Gini=0.000
# X2 < 1.785 Gini=1.000
# X2 < 1.170 Gini=0.494
# X2 < 2.813 Gini=0.640
# X2 < 2.620 Gini=0.819
# X2 < 2.209 Gini=0.934
# X2 < 3.163 Gini=0.278
# X2 < 3.339 Gini=0.494
# X2 < 0.477 Gini=0.500
# X2 < 3.235 Gini=0.408
# X2 < 3.320 Gini=0.469


def test_test_split():
    """Check that _test_split partitions rows around a threshold."""
    from decision_tree import DecisionTree
    rows = [[1.0, 2.0, '1'], [3.0, 4.0, '0'], [5.0, 6.0, '1'], [7.0, 8.0, '0']]
    data = pd.DataFrame(rows)
    first_col = data.columns[0]
    # Reference partitions built with plain boolean indexing.
    expected_left = data[data[first_col] < 3]
    expected_right = data[data[first_col] >= 3]
    dtree = DecisionTree(1, 1)
    left_result = dtree._test_split(0, 3, data)[0]
    assert left_result.equals(expected_left)
    right_result = dtree._test_split(0, 3, data)[1]
    assert right_result.equals(expected_right)


def test_calculate_gini():
    """Check _calculate_gini against the known scores for the fixture."""
    from decision_tree import DecisionTree
    dtree = DecisionTree(1, 1)
    data = pd.DataFrame(DATASET2)
    class_values = [0.0, 1.0]
    for position, (threshold, column) in enumerate(DATASET2_VALUES):
        groups = list(dtree._test_split(column, threshold, data))
        score = dtree._calculate_gini(groups, class_values)
        assert round(score, 3) == DATASET2_GINI[position]


def test__get_split():
    """Check that _get_split finds the perfect split of the fixture.

    Every class-0 row of the fixture has x below 5 and every class-1 row has
    x above 5, so the best split is on 'x', scores a gini of 0.0, and yields
    the same two groups as splitting 'x' at 5.
    """
    from decision_tree import DecisionTree
    data_table = pd.DataFrame(DATASET2)
    dtree = DecisionTree(1, 1)
    split_col, split_value, split_gini, split_groups = dtree._get_split(data_table)
    assert split_col == 'x'
    # 6.642287351 is the smallest class-1 x value: the first threshold that
    # puts all class-0 rows on the left.
    assert abs(split_value - 6.642287351) < 1e-9
    assert split_gini == 0.0
    # Address the column by its label; the frame has no column named 0.
    expected_groups = dtree._test_split('x', 5, data_table)
    assert len(split_groups) == len(expected_groups)
    for found, expected in zip(split_groups, expected_groups):
        assert found.to_dict() == expected.to_dict()