diff --git a/setup.py b/setup.py
index 24d582d..27ad944 100644
--- a/setup.py
+++ b/setup.py
@@ -8,7 +8,8 @@
     author_email="",
     license="MIT",
     package_dir={'': 'src'},
-    py_modules=["linked_list", "stack", "dbl_linked_list", "queue_ds", "deque", "binheap", "graph", "weighted_graph", "bst"],
-    install_requires=[],
+    py_modules=["linked_list", "stack", "dbl_linked_list", "queue_ds", "deque", "binheap", "graph", "weighted_graph", "bst", "decision_tree"],
+    install_requires=["pandas"],
     extras_require={"test": ["pytest", "pytest-watch", "pytest-cov", "tox"]},
     entry_points={}
diff --git a/src/decision_tree.py b/src/decision_tree.py
new file mode 100644
index 0000000..d36b78b
--- /dev/null
+++ b/src/decision_tree.py
@@ -0,0 +1,112 @@
+import pandas as pd
+
+
+class TreeNode(object):
+    """Node object for use in a decision tree classifier."""
+
+    def __init__(self, data, split_value, split_gini, split_col, left=None, right=None, label=None):
+        """Initialize a node holding one split decision."""
+        self.left = left
+        self.right = right
+        self.data = data
+        self.split_value = split_value
+        self.split_gini = split_gini
+        self.split_col = split_col
+        self.label = label
+
+    def _has_children(self):
+        """Return True if this node has at least one child."""
+        return bool(self.left or self.right)
+
+
+class DecisionTree(object):
+    """Decision tree classifier over pandas DataFrames.
+
+    The last column of the training frame is treated as the class label.
+    """
+
+    def __init__(self, max_depth, min_leaf_size):
+        """Store the stopping criteria used while growing the tree."""
+        self.max_depth = max_depth
+        self.min_leaf_size = min_leaf_size
+        self.root = None
+
+    def fit(self, data):
+        """Build the tree that fits ``data``."""
+        split_col, split_value, split_gini, split_groups = self._get_split(data)
+        self.root = TreeNode(data, split_value, split_gini, split_col)
+        self.root.left = self._build_tree(split_groups[0])
+        self.root.right = self._build_tree(split_groups[1])
+
+    def _build_tree(self, data, depth_count=1):
+        """Recursively build and return the subtree for ``data``."""
+        split_col, split_value, split_gini, split_groups = self._get_split(data)
+        new_node = TreeNode(data, split_value, split_gini, split_col)
+        if not data.empty:
+            # Majority class of the rows reaching this node.
+            new_node.label = data[data.columns[-1]].mode()[0]
+        if not self._can_split(split_gini, depth_count, len(data)):
+            return new_node
+        new_node.left = self._build_tree(split_groups[0], depth_count + 1)
+        new_node.right = self._build_tree(split_groups[1], depth_count + 1)
+        return new_node
+
+    def _can_split(self, gini, depth_count, data_size):
+        """Return True if a node at this depth/size may split further."""
+        if gini == 0.0:
+            return False
+        if depth_count >= self.max_depth:
+            return False
+        if data_size <= self.min_leaf_size:
+            return False
+        return True
+
+    def _depth(self, start=''):
+        """Return the integer depth of the tree (or of subtree ``start``)."""
+        def depth_wrapped(node):
+            if node is None:
+                return 0
+            return max(depth_wrapped(node.right), depth_wrapped(node.left)) + 1
+        if start == '':
+            return depth_wrapped(self.root)
+        return depth_wrapped(start)
+
+    def _calculate_gini(self, groups, class_values):
+        """Calculate the gini index for a candidate split."""
+        gini = 0.0
+        for class_value in class_values:
+            for group in groups:
+                size = len(group)
+                if size == 0:
+                    continue
+                proportion = len(group[group[group.columns[-1]] == class_value]) / float(size)
+                gini += proportion * (1.0 - proportion)
+        return gini
+
+    def _get_split(self, data):
+        """Choose the split point with the lowest gini index."""
+        classes = data[data.columns[-1]].unique()
+        split_col, split_value, split_gini, split_groups = float('inf'), float('inf'), float('inf'), None
+        for col in data.columns.values[:-1]:
+            for _, row in data.iterrows():
+                groups = self._test_split(col, row[col], data)
+                gini = self._calculate_gini(groups, classes)
+                if gini < split_gini:
+                    split_col, split_value, split_gini, split_groups = col, row[col], gini, groups
+        return split_col, split_value, split_gini, split_groups
+
+    def _test_split(self, col, value, data):
+        """Split ``data`` on ``col`` at ``value``; return (left, right) frames."""
+        left = data[data[col] < value]
+        right = data[data[col] >= value]
+        return left, right
+
+    def predict(self, data):
+        """Return the predicted class label for one row of features."""
+        curr_node = self.root
+        while curr_node._has_children():
+            if data[curr_node.split_col] < curr_node.split_value:
+                curr_node = curr_node.left
+            else:
+                curr_node = curr_node.right
+        return curr_node.label
diff --git a/src/flowers_data.csv b/src/flowers_data.csv
new file mode 100644
index 0000000..63fed67
--- /dev/null
+++ b/src/flowers_data.csv
@@ -0,0 +1,101 @@
+petal length (cm),petal width (cm),sepal length (cm),sepal width (cm),target,class_names
+1.4,0.2,5.1,3.5,0,setosa
+1.4,0.2,4.9,3.0,0,setosa
+1.3,0.2,4.7,3.2,0,setosa
+1.5,0.2,4.6,3.1,0,setosa
+1.4,0.2,5.0,3.6,0,setosa
+1.7,0.4,5.4,3.9,0,setosa
+1.4,0.3,4.6,3.4,0,setosa
+1.5,0.2,5.0,3.4,0,setosa
+1.4,0.2,4.4,2.9,0,setosa
+1.5,0.1,4.9,3.1,0,setosa
+1.5,0.2,5.4,3.7,0,setosa
+1.6,0.2,4.8,3.4,0,setosa
+1.4,0.1,4.8,3.0,0,setosa
+1.1,0.1,4.3,3.0,0,setosa
+1.2,0.2,5.8,4.0,0,setosa
+1.5,0.4,5.7,4.4,0,setosa +1.3,0.4,5.4,3.9,0,setosa +1.4,0.3,5.1,3.5,0,setosa +1.7,0.3,5.7,3.8,0,setosa +1.5,0.3,5.1,3.8,0,setosa +1.7,0.2,5.4,3.4,0,setosa +1.5,0.4,5.1,3.7,0,setosa +1.0,0.2,4.6,3.6,0,setosa +1.7,0.5,5.1,3.3,0,setosa +1.9,0.2,4.8,3.4,0,setosa +1.6,0.2,5.0,3.0,0,setosa +1.6,0.4,5.0,3.4,0,setosa +1.5,0.2,5.2,3.5,0,setosa +1.4,0.2,5.2,3.4,0,setosa +1.6,0.2,4.7,3.2,0,setosa +1.6,0.2,4.8,3.1,0,setosa +1.5,0.4,5.4,3.4,0,setosa +1.5,0.1,5.2,4.1,0,setosa +1.4,0.2,5.5,4.2,0,setosa +1.5,0.1,4.9,3.1,0,setosa +1.2,0.2,5.0,3.2,0,setosa +1.3,0.2,5.5,3.5,0,setosa +1.5,0.1,4.9,3.1,0,setosa +1.3,0.2,4.4,3.0,0,setosa +1.5,0.2,5.1,3.4,0,setosa +1.3,0.3,5.0,3.5,0,setosa +1.3,0.3,4.5,2.3,0,setosa +1.3,0.2,4.4,3.2,0,setosa +1.6,0.6,5.0,3.5,0,setosa +1.9,0.4,5.1,3.8,0,setosa +1.4,0.3,4.8,3.0,0,setosa +1.6,0.2,5.1,3.8,0,setosa +1.4,0.2,4.6,3.2,0,setosa +1.5,0.2,5.3,3.7,0,setosa +1.4,0.2,5.0,3.3,0,setosa +4.7,1.4,7.0,3.2,1,versicolor +4.5,1.5,6.4,3.2,1,versicolor +4.9,1.5,6.9,3.1,1,versicolor +4.0,1.3,5.5,2.3,1,versicolor +4.6,1.5,6.5,2.8,1,versicolor +4.5,1.3,5.7,2.8,1,versicolor +4.7,1.6,6.3,3.3,1,versicolor +3.3,1.0,4.9,2.4,1,versicolor +4.6,1.3,6.6,2.9,1,versicolor +3.9,1.4,5.2,2.7,1,versicolor +3.5,1.0,5.0,2.0,1,versicolor +4.2,1.5,5.9,3.0,1,versicolor +4.0,1.0,6.0,2.2,1,versicolor +4.7,1.4,6.1,2.9,1,versicolor +3.6,1.3,5.6,2.9,1,versicolor +4.4,1.4,6.7,3.1,1,versicolor +4.5,1.5,5.6,3.0,1,versicolor +4.1,1.0,5.8,2.7,1,versicolor +4.5,1.5,6.2,2.2,1,versicolor +3.9,1.1,5.6,2.5,1,versicolor +4.8,1.8,5.9,3.2,1,versicolor +4.0,1.3,6.1,2.8,1,versicolor +4.9,1.5,6.3,2.5,1,versicolor +4.7,1.2,6.1,2.8,1,versicolor +4.3,1.3,6.4,2.9,1,versicolor +4.4,1.4,6.6,3.0,1,versicolor +4.8,1.4,6.8,2.8,1,versicolor +5.0,1.7,6.7,3.0,1,versicolor +4.5,1.5,6.0,2.9,1,versicolor +3.5,1.0,5.7,2.6,1,versicolor +3.8,1.1,5.5,2.4,1,versicolor +3.7,1.0,5.5,2.4,1,versicolor +3.9,1.2,5.8,2.7,1,versicolor +5.1,1.6,6.0,2.7,1,versicolor +4.5,1.5,5.4,3.0,1,versicolor +4.5,1.6,6.0,3.4,1,versicolor 
+4.7,1.5,6.7,3.1,1,versicolor
+4.4,1.3,6.3,2.3,1,versicolor
+4.1,1.3,5.6,3.0,1,versicolor
+4.0,1.3,5.5,2.5,1,versicolor
+4.4,1.2,5.5,2.6,1,versicolor
+4.6,1.4,6.1,3.0,1,versicolor
+4.0,1.2,5.8,2.6,1,versicolor
+3.3,1.0,5.0,2.3,1,versicolor
+4.2,1.3,5.6,2.7,1,versicolor
+4.2,1.2,5.7,3.0,1,versicolor
+4.2,1.3,5.7,2.9,1,versicolor
+4.3,1.3,6.2,2.9,1,versicolor
+3.0,1.1,5.1,2.5,1,versicolor
+4.1,1.3,5.7,2.8,1,versicolor
diff --git a/src/test_dataset2.csv b/src/test_dataset2.csv
new file mode 100644
index 0000000..e2a182d
--- /dev/null
+++ b/src/test_dataset2.csv
@@ -0,0 +1,11 @@
+x,y,class_names
+2.771244718,1.784783929,0
+1.728571309,1.169761413,0
+3.678319846,2.81281357,0
+3.961043357,2.61995032,0
+2.999208922,2.209014212,0
+7.497545867,3.162953546,1
+9.00220326,3.339047188,1
+7.444542326,0.476683375,1
+10.12493903,3.234550982,1
+6.642287351,3.319983761,1
diff --git a/src/test_decision_tree.py b/src/test_decision_tree.py
new file mode 100644
index 0000000..b0f956f
--- /dev/null
+++ b/src/test_decision_tree.py
@@ -0,0 +1,103 @@
+import pytest
+import pandas as pd
+
+
+DATASET2 = pd.read_csv("src/test_dataset2.csv")
+DATASET2_VALUES = [
+    (2.771, 'x'),
+    (1.728, 'x'),
+    (3.678, 'x'),
+    (3.961, 'x'),
+    (2.999, 'x'),
+    (7.497, 'x'),
+    (9.002, 'x'),
+    (7.444, 'x'),
+    (10.124, 'x'),
+    (6.642, 'x'),
+    (1.784, 'y'),
+    (1.168, 'y'),
+    (2.812, 'y'),
+    (2.619, 'y'),
+    (2.209, 'y'),
+    (3.162, 'y'),
+    (3.339, 'y'),
+    (0.476, 'y'),
+    (3.234, 'y'),
+    (3.319, 'y'),
+]
+
+DATASET2_GINI = [
+    0.494,
+    0.500,
+    0.408,
+    0.278,
+    0.469,
+    0.408,
+    0.469,
+    0.278,
+    0.494,
+    0.000,
+    1.000,
+    0.494,
+    0.640,
+    0.819,
+    0.934,
+    0.278,
+    0.494,
+    0.500,
+    0.408,
+    0.469,
+]
+
+# X1 < 2.771 Gini=0.494
+# X1 < 1.729 Gini=0.500
+# X1 < 3.678 Gini=0.408
+# X1 < 3.961 Gini=0.278
+# X1 < 2.999 Gini=0.469
+# X1 < 7.498 Gini=0.408
+# X1 < 9.002 Gini=0.469
+# X1 < 7.445 Gini=0.278
+# X1 < 10.125 Gini=0.494
+# X1 < 6.642 Gini=0.000
+# X2 < 1.785 Gini=1.000
+# X2 < 1.170 Gini=0.494
+# X2 < 2.813 Gini=0.640
+# X2 < 2.620 Gini=0.819
+# X2 < 2.209 Gini=0.934
+# X2 < 3.163 Gini=0.278
+# X2 < 3.339 Gini=0.494
+# X2 < 0.477 Gini=0.500
+# X2 < 3.235 Gini=0.408
+# X2 < 3.320 Gini=0.469
+
+
+def test_test_split():
+    """Test _test_split method with a small known dataset."""
+    from decision_tree import DecisionTree
+    data = pd.DataFrame([[1.0, 2.0, '1'], [3.0, 4.0, '0'], [5.0, 6.0, '1'], [7.0, 8.0, '0']])
+    left = data[data[data.columns[0]] < 3]
+    right = data[data[data.columns[0]] >= 3]
+    dtree = DecisionTree(1, 1)
+    assert dtree._test_split(0, 3, data)[0].equals(left)
+    assert dtree._test_split(0, 3, data)[1].equals(right)
+
+
+def test_calculate_gini():
+    """Test _calculate_gini against the known gini values for DATASET2."""
+    from decision_tree import DecisionTree
+    dtree = DecisionTree(1, 1)
+    for (value, col), expected in zip(DATASET2_VALUES, DATASET2_GINI):
+        left, right = dtree._test_split(col, value, DATASET2)
+        assert round(dtree._calculate_gini([left, right], [0.0, 1.0]), 3) == expected
+
+
+def test_get_split():
+    """Test that _get_split finds the documented optimal split (x < 6.642)."""
+    from decision_tree import DecisionTree
+    dtree = DecisionTree(1, 1)
+    split_col, split_value, split_gini, split_groups = dtree._get_split(DATASET2)
+    assert split_col == 'x'
+    assert round(split_gini, 3) == 0.0
+    expected = dtree._test_split('x', 6.642287351, DATASET2)
+    for got, want in zip(split_groups, expected):
+        assert got.to_dict() == want.to_dict()