From 2f72f7d2626a0f38abd52692c8cf4b247a9f9dff Mon Sep 17 00:00:00 2001 From: Ted Callahan Date: Wed, 8 Feb 2017 14:05:11 -0800 Subject: [PATCH 01/13] Initial commit for decision tree. --- src/decision_tree.py | 0 src/test_decision_tree.py | 0 2 files changed, 0 insertions(+), 0 deletions(-) create mode 100644 src/decision_tree.py create mode 100644 src/test_decision_tree.py diff --git a/src/decision_tree.py b/src/decision_tree.py new file mode 100644 index 0000000..e69de29 diff --git a/src/test_decision_tree.py b/src/test_decision_tree.py new file mode 100644 index 0000000..e69de29 From 2ad032de48b13b1d3a8fb442873792b0e2ebe13e Mon Sep 17 00:00:00 2001 From: Ted Callahan Date: Wed, 8 Feb 2017 14:06:31 -0800 Subject: [PATCH 02/13] including data file. --- src/flowers_data.csv | 101 +++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 101 insertions(+) create mode 100644 src/flowers_data.csv diff --git a/src/flowers_data.csv b/src/flowers_data.csv new file mode 100644 index 0000000..63fed67 --- /dev/null +++ b/src/flowers_data.csv @@ -0,0 +1,101 @@ +petal length (cm),petal width (cm),sepal length (cm),sepal width (cm),target,class_names +1.4,0.2,5.1,3.5,0,setosa +1.4,0.2,4.9,3.0,0,setosa +1.3,0.2,4.7,3.2,0,setosa +1.5,0.2,4.6,3.1,0,setosa +1.4,0.2,5.0,3.6,0,setosa +1.7,0.4,5.4,3.9,0,setosa +1.4,0.3,4.6,3.4,0,setosa +1.5,0.2,5.0,3.4,0,setosa +1.4,0.2,4.4,2.9,0,setosa +1.5,0.1,4.9,3.1,0,setosa +1.5,0.2,5.4,3.7,0,setosa +1.6,0.2,4.8,3.4,0,setosa +1.4,0.1,4.8,3.0,0,setosa +1.1,0.1,4.3,3.0,0,setosa +1.2,0.2,5.8,4.0,0,setosa +1.5,0.4,5.7,4.4,0,setosa +1.3,0.4,5.4,3.9,0,setosa +1.4,0.3,5.1,3.5,0,setosa +1.7,0.3,5.7,3.8,0,setosa +1.5,0.3,5.1,3.8,0,setosa +1.7,0.2,5.4,3.4,0,setosa +1.5,0.4,5.1,3.7,0,setosa +1.0,0.2,4.6,3.6,0,setosa +1.7,0.5,5.1,3.3,0,setosa +1.9,0.2,4.8,3.4,0,setosa +1.6,0.2,5.0,3.0,0,setosa +1.6,0.4,5.0,3.4,0,setosa +1.5,0.2,5.2,3.5,0,setosa +1.4,0.2,5.2,3.4,0,setosa +1.6,0.2,4.7,3.2,0,setosa +1.6,0.2,4.8,3.1,0,setosa +1.5,0.4,5.4,3.4,0,setosa +1.5,0.1,5.2,4.1,0,setosa +1.4,0.2,5.5,4.2,0,setosa +1.5,0.1,4.9,3.1,0,setosa +1.2,0.2,5.0,3.2,0,setosa +1.3,0.2,5.5,3.5,0,setosa +1.5,0.1,4.9,3.1,0,setosa +1.3,0.2,4.4,3.0,0,setosa +1.5,0.2,5.1,3.4,0,setosa +1.3,0.3,5.0,3.5,0,setosa +1.3,0.3,4.5,2.3,0,setosa +1.3,0.2,4.4,3.2,0,setosa +1.6,0.6,5.0,3.5,0,setosa +1.9,0.4,5.1,3.8,0,setosa +1.4,0.3,4.8,3.0,0,setosa +1.6,0.2,5.1,3.8,0,setosa +1.4,0.2,4.6,3.2,0,setosa +1.5,0.2,5.3,3.7,0,setosa +1.4,0.2,5.0,3.3,0,setosa +4.7,1.4,7.0,3.2,1,versicolor +4.5,1.5,6.4,3.2,1,versicolor +4.9,1.5,6.9,3.1,1,versicolor +4.0,1.3,5.5,2.3,1,versicolor +4.6,1.5,6.5,2.8,1,versicolor +4.5,1.3,5.7,2.8,1,versicolor +4.7,1.6,6.3,3.3,1,versicolor +3.3,1.0,4.9,2.4,1,versicolor +4.6,1.3,6.6,2.9,1,versicolor +3.9,1.4,5.2,2.7,1,versicolor +3.5,1.0,5.0,2.0,1,versicolor +4.2,1.5,5.9,3.0,1,versicolor +4.0,1.0,6.0,2.2,1,versicolor +4.7,1.4,6.1,2.9,1,versicolor +3.6,1.3,5.6,2.9,1,versicolor +4.4,1.4,6.7,3.1,1,versicolor +4.5,1.5,5.6,3.0,1,versicolor +4.1,1.0,5.8,2.7,1,versicolor +4.5,1.5,6.2,2.2,1,versicolor +3.9,1.1,5.6,2.5,1,versicolor +4.8,1.8,5.9,3.2,1,versicolor +4.0,1.3,6.1,2.8,1,versicolor +4.9,1.5,6.3,2.5,1,versicolor +4.7,1.2,6.1,2.8,1,versicolor +4.3,1.3,6.4,2.9,1,versicolor +4.4,1.4,6.6,3.0,1,versicolor +4.8,1.4,6.8,2.8,1,versicolor +5.0,1.7,6.7,3.0,1,versicolor +4.5,1.5,6.0,2.9,1,versicolor +3.5,1.0,5.7,2.6,1,versicolor +3.8,1.1,5.5,2.4,1,versicolor +3.7,1.0,5.5,2.4,1,versicolor +3.9,1.2,5.8,2.7,1,versicolor +5.1,1.6,6.0,2.7,1,versicolor +4.5,1.5,5.4,3.0,1,versicolor +4.5,1.6,6.0,3.4,1,versicolor +4.7,1.5,6.7,3.1,1,versicolor +4.4,1.3,6.3,2.3,1,versicolor +4.1,1.3,5.6,3.0,1,versicolor +4.0,1.3,5.5,2.5,1,versicolor +4.4,1.2,5.5,2.6,1,versicolor +4.6,1.4,6.1,3.0,1,versicolor +4.0,1.2,5.8,2.6,1,versicolor +3.3,1.0,5.0,2.3,1,versicolor +4.2,1.3,5.6,2.7,1,versicolor +4.2,1.2,5.7,3.0,1,versicolor +4.2,1.3,5.7,2.9,1,versicolor +4.3,1.3,6.2,2.9,1,versicolor +3.0,1.1,5.1,2.5,1,versicolor +4.1,1.3,5.7,2.8,1,versicolor From a944f660a7f619d28ccf88662f1b62b19273d5d8 Mon Sep 17 00:00:00 2001 From: Ted Callahan Date: Wed, 8 Feb 2017 14:13:34 -0800 Subject: [PATCH 03/13] Wrote out initial tree class. --- src/decision_tree.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/src/decision_tree.py b/src/decision_tree.py index e69de29..5505c7c 100644 --- a/src/decision_tree.py +++ b/src/decision_tree.py @@ -0,0 +1,17 @@ + +class DecisionTree(object): + """Define a Decision Tree class object.""" + + def __init__(self, max_depth, min_leaf_size): + """Initialize a Decision Tree object.""" + self.max_depth = max_depth + self.min_leaf_size = min_leaf_size + self.root = None + + def fit(self, data): + """Create a tree to fit the data.""" + pass + + def predict(self, data): + """Given data, return labels for that data.""" + pass From cb95f603ca1dade94be7d54417784a1b5ceffc4a Mon Sep 17 00:00:00 2001 From: Ted Callahan Date: Thu, 9 Feb 2017 11:17:07 -0800 Subject: [PATCH 04/13] added some stuff. --- src/decision_tree.py | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/src/decision_tree.py b/src/decision_tree.py index 5505c7c..392e5e9 100644 --- a/src/decision_tree.py +++ b/src/decision_tree.py @@ -1,4 +1,17 @@ +class TreeNode(object): + """Define a Node object for use in a decision tree classifier.""" + + def __init__(self, split_value, data_set, label=None, left=None, right=None, parent=None): + """Initialize a node object for a decision tree classifier.""" + self.left = left + self.right = right + self.parent = parent + self.data_set = data_set + self.split_value = split_value + self.label = label + + class DecisionTree(object): """Define a Decision Tree class object.""" @@ -7,11 +20,17 @@ def __init__(self, max_depth, min_leaf_size): self.max_depth = max_depth self.min_leaf_size = min_leaf_size self.root = None + self.class_values = [] def fit(self, data): """Create a tree to fit the data.""" pass + def _calculate_gini(self, data): + """Calculate gini for a given data_set.""" + pass + def predict(self, data): """Given data, return labels for that data.""" pass + From 38e6948004c6cdc774aa5796f3556fe01d437b7d Mon Sep 17 00:00:00 2001 From: pasaunders Date: Thu, 9 Feb 2017 13:14:06 -0800 Subject: [PATCH 05/13] gini and split functions added --- src/decision_tree.py | 42 +++++++++++++++++++++++++++++++++++++++--- 1 file changed, 39 insertions(+), 3 deletions(-) diff --git a/src/decision_tree.py b/src/decision_tree.py index 392e5e9..fb67629 100644 --- a/src/decision_tree.py +++ b/src/decision_tree.py @@ -1,3 +1,5 @@ +import pandas as pd + class TreeNode(object): """Define a Node object for use in a decision tree classifier.""" @@ -26,11 +28,45 @@ def fit(self, data): """Create a tree to fit the data.""" pass - def _calculate_gini(self, data): + def _calculate_gini(self, groups, class_values): """Calculate gini for a given data_set.""" - pass + gini = 0.0 + for class_value in class_values: + for group in groups: + size = len(group) + if size == 0: + continue + proportion = [row[-1] for row in group].count(class_value) / float(size) + gini += (proportion * (1.0 - proportion)) + return gini + + def _get_split(self, data): + """Choose a split point with lowest gini index.""" + classes = data["class_names"].unique() + split_col_index, split_value, split_gini, split_groups =\ + float('inf'), float('inf'), float('inf'), None + for col_index in range(len(data.columns) - 2): + for row in data: + groups = self._test_split(col_index, row[col_index], data) + gini = self._calculate_gini(groups, classes) + if gini < split_gini: + split_col_index, split_value, split_gini, split_groups =\ + col_index, row[col_index], gini, groups + return split_col_index, split_value, split_groups + + def _calculate_split(self, data): + lowest_gini = 1.0 + lowest_row = None + lowest_col = None + for row in data: + for col in data: + gini = self._calculate_gini(row, col) + if gini < lowest_gini: + lowest_gini = gini + lowest_row = row + lowest_col = col + return lowest_row, lowest_col def predict(self, data): """Given data, return labels for that data.""" pass - From 7b6964de7d790b88621fc98eafcda2d4560125bc Mon Sep 17 00:00:00 2001 From: Ted Callahan Date: Thu, 9 Feb 2017 13:17:16 -0800 Subject: [PATCH 06/13] test database. --- src/decision_tree.py | 12 +++++++++++- src/test_dataset2.csv | 11 +++++++++++ 2 files changed, 22 insertions(+), 1 deletion(-) create mode 100644 src/test_dataset2.csv diff --git a/src/decision_tree.py b/src/decision_tree.py index 392e5e9..b1c1708 100644 --- a/src/decision_tree.py +++ b/src/decision_tree.py @@ -1,3 +1,5 @@ +import pandas as pd + class TreeNode(object): """Define a Node object for use in a decision tree classifier.""" @@ -28,7 +30,15 @@ def fit(self, data): def _calculate_gini(self, data): """Calculate gini for a given data_set.""" - pass + gini = 0.0 + for class_name in data["class_names"].unique(): + for col in data.columns[:-2]: + total_size = len(data) + if total_size == 0: + continue + proportion = [row[-1] for row in col].count(class_name) / float(total_size) + gini += (proportion * (1.0 - proportion)) + return gini def predict(self, data): """Given data, return labels for that data.""" diff --git a/src/test_dataset2.csv b/src/test_dataset2.csv new file mode 100644 index 0000000..e2a182d --- /dev/null +++ b/src/test_dataset2.csv @@ -0,0 +1,11 @@ +x,y,class_names +2.771244718,1.784783929,0 +1.728571309,1.169761413,0 +3.678319846,2.81281357,0 +3.961043357,2.61995032,0 +2.999208922,2.209014212,0 +7.497545867,3.162953546,1 +9.00220326,3.339047188,1 +7.444542326,0.476683375,1 +10.12493903,3.234550982,1 +6.642287351,3.319983761,1 \ No newline at end of file From e8e6f4cc84187470d4b4729354eaed22d521ed04 Mon Sep 17 00:00:00 2001 From: Ted Callahan Date: Thu, 9 Feb 2017 18:37:19 -0800 Subject: [PATCH 07/13] Wrote out draft of fit, predict, and helper functions. --- src/decision_tree.py | 78 +++++++++++++++++++++++++++++++++++++++----- 1 file changed, 70 insertions(+), 8 deletions(-) diff --git a/src/decision_tree.py b/src/decision_tree.py index 25a4761..f2a6795 100644 --- a/src/decision_tree.py +++ b/src/decision_tree.py @@ -4,14 +4,27 @@ class TreeNode(object): """Define a Node object for use in a decision tree classifier.""" - def __init__(self, split_value, data_set, label=None, left=None, right=None, parent=None): + def __init__(self, data, split_value, split_gini, split_col, left=None, right=None, parent=None): """Initialize a node object for a decision tree classifier.""" self.left = left self.right = right - self.parent = parent - self.data_set = data_set + self.data = data self.split_value = split_value - self.label = label + self.split_gini = split_gini + self.split_col = split_col + + def _has_children(self): + """Return True or False if Node has children.""" + if self.right or self.left: + return True + return False + + def _return_children(self): + """Return all children of a Node.""" + if self.left and self.right: + return [self.left, self.right] + elif self.left or self.right: + return [self.left] if self.left else [self.right] class DecisionTree(object): @@ -22,11 +35,53 @@ def __init__(self, max_depth, min_leaf_size): self.max_depth = max_depth self.min_leaf_size = min_leaf_size self.root = None - self.class_values = [] def fit(self, data): """Create a tree to fit the data.""" - pass + split_col, split_value, split_gini, split_groups = self._get_split(data) + new_node = TreeNode(data, split_value, split_gini, split_col) + if not self.root: + self.root = new_node + if self._can_split(split_gini, len(data)): + self.root.left = self.fit(split_groups[0]) + self.root.right = self.fit(split_groups[1]) + else: + return new_node + + # def _build_tree(self, data): + # """Given a node, build the tree.""" + # split_col, split_value, split_gini, split_groups = self._get_split(data) + # new_node = TreeNode(split_value, data, label=split_col) + # if self._can_split(split_gini, len(data)): + # self.root.left = self._build_tree(split_groups[0], new_node) + # self.root.right = self._build_tree(split_groups[1], new_node) + # else: + # return new_node + + def _can_split(self, gini, data_size): + """Given a gini value, determine whether or not tree can split.""" + if gini == 0.0: + return False + elif self._depth() >= self.max_depth: + return False + elif data_size <= self.min_leaf_size: + return False + else: + return True + + def _depth(self, start=''): + """Return the integer depth of the BST.""" + def depth_wrapped(start): + if start is None: + return 0 + else: + right_depth = depth_wrapped(start.right) + left_depth = depth_wrapped(start.left) + return max(right_depth, left_depth) + 1 + if start is '': + return depth_wrapped(self.root) + else: + return depth_wrapped(start) def _calculate_gini(self, groups, class_values): """Calculate gini for a given data_set.""" @@ -53,7 +108,7 @@ def _get_split(self, data): if gini < split_gini: split_col, split_value, split_gini, split_groups =\ col, row[col], gini, groups - return split_col, split_value, split_groups + return split_col, split_value, split_gini, split_groups def _test_split(self, col, value, data): """Given a dataset, column index, and value, split the dataset.""" @@ -68,4 +123,11 @@ def _test_split(self, col, value, data): def predict(self, data): """Given data, return labels for that data.""" - pass + curr_node = self.root + while curr_node._has_children(): + if data[curr_node.label]: + curr_node = curr_node.right + else: + curr_node = curr_node.left + return curr_node.label + From 772a3e6792be4c351b39992a5cf87974447cd9f3 Mon Sep 17 00:00:00 2001 From: pasaunders Date: Thu, 9 Feb 2017 18:37:21 -0800 Subject: [PATCH 08/13] testing _get_split --- setup.py | 2 +- src/test_decision_tree.py | 11 +++++++++++ 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 24d582d..27ad944 100644 --- a/setup.py +++ b/setup.py @@ -8,7 +8,7 @@ author_email="", license="MIT", package_dir={'': 'src'}, - py_modules=["linked_list", "stack", "dbl_linked_list", "queue_ds", "deque", "binheap", "graph", "weighted_graph", "bst"], + py_modules=["linked_list", "stack", "dbl_linked_list", "queue_ds", "deque", "binheap", "graph", "weighted_graph", "bst", "pandas"], install_requires=[], extras_require={"test": ["pytest", "pytest-watch", "pytest-cov", "tox"]}, entry_points={} diff --git a/src/test_decision_tree.py b/src/test_decision_tree.py index e65c19b..1c83f42 100644 --- a/src/test_decision_tree.py +++ b/src/test_decision_tree.py @@ -90,3 +90,14 @@ def test_calculate_gini(): for i in range(len(DATASET2_VALUES)): left, right = dtree._test_split(DATASET2_VALUES[i][1], DATASET2_VALUES[i][0], data) assert round(dtree._calculate_gini([left, right], [0.0, 1.0]), 3) == DATASET2_GINI[i] + + +def test__get_split(): + """Test get optimal split point.""" + from decision_tree import DecisionTree + data_table = pd.DataFrame(DATASET2) + dtree = DecisionTree(1, 1) + split = dtree._get_split(data_table) + # import pdb; pdb.set_trace() + for i in range(len(split[2])): + assert split[2][i].to_dict() == dtree._test_split(0, 5, data_table)[i].to_dict() \ No newline at end of file From 5646a18552bad0a1ee3390985b7a648a806d6f70 Mon Sep 17 00:00:00 2001 From: Ted Callahan Date: Sat, 11 Feb 2017 11:15:36 -0800 Subject: [PATCH 09/13] Pushing up tinkering changes and troubleshooting gini splits. --- src/decision_tree.py | 56 ++++++++++++++++++++++---------------------- 1 file changed, 28 insertions(+), 28 deletions(-) diff --git a/src/decision_tree.py b/src/decision_tree.py index f2a6795..ea67541 100644 --- a/src/decision_tree.py +++ b/src/decision_tree.py @@ -4,7 +4,7 @@ class TreeNode(object): """Define a Node object for use in a decision tree classifier.""" - def __init__(self, data, split_value, split_gini, split_col, left=None, right=None, parent=None): + def __init__(self, data, split_value, split_gini, split_col, left=None, right=None, label=None): """Initialize a node object for a decision tree classifier.""" self.left = left self.right = right @@ -12,6 +12,7 @@ def __init__(self, data, split_value, split_gini, split_col, left=None, right=No self.split_value = split_value self.split_gini = split_gini self.split_col = split_col + self.label = label def _has_children(self): """Return True or False if Node has children.""" @@ -19,13 +20,6 @@ def _has_children(self): return True return False - def _return_children(self): - """Return all children of a Node.""" - if self.left and self.right: - return [self.left, self.right] - elif self.left or self.right: - return [self.left] if self.left else [self.right] - class DecisionTree(object): """Define a Decision Tree class object.""" @@ -40,31 +34,37 @@ def fit(self, data): """Create a tree to fit the data.""" split_col, split_value, split_gini, split_groups = self._get_split(data) new_node = TreeNode(data, split_value, split_gini, split_col) - if not self.root: - self.root = new_node - if self._can_split(split_gini, len(data)): - self.root.left = self.fit(split_groups[0]) - self.root.right = self.fit(split_groups[1]) + self.root = new_node + self.root.left = self._build_tree(split_groups[0]) + self.root.right = self._build_tree(split_groups[1]) + + def _build_tree(self, data, depth_count=0): + """Given a node, build the tree.""" + split_col, split_value, split_gini, split_groups = self._get_split(data) + new_node = TreeNode(data, split_value, split_gini, split_col) + try: + new_node.label = data[data.columns[-1]].mode()[0] + except: + pass + if self._can_split(split_gini, depth_count, len(data)): + print("splitting") + new_node.left = self._build_tree(split_groups[0], depth_count + 1) + new_node.right = self._build_tree(split_groups[1], depth_count + 1) else: + print("terminating") return new_node + return new_node - # def _build_tree(self, data): - # """Given a node, build the tree.""" - # split_col, split_value, split_gini, split_groups = self._get_split(data) - # new_node = TreeNode(split_value, data, label=split_col) - # if self._can_split(split_gini, len(data)): - # self.root.left = self._build_tree(split_groups[0], new_node) - # self.root.right = self._build_tree(split_groups[1], new_node) - # else: - # return new_node - - def _can_split(self, gini, data_size): + def _can_split(self, gini, depth_count, data_size): """Given a gini value, determine whether or not tree can split.""" if gini == 0.0: + print("Bad gini") return False - elif self._depth() >= self.max_depth: + elif depth_count >= self.max_depth: + print("bad depth") return False elif data_size <= self.min_leaf_size: + print("bad data size") return False else: return True @@ -105,9 +105,9 @@ def _get_split(self, data): row = data.iloc[i] groups = self._test_split(col, row[col], data) gini = self._calculate_gini(groups, classes) - if gini < split_gini: - split_col, split_value, split_gini, split_groups =\ - col, row[col], gini, groups + if gini < split_gini: + split_col, split_value, split_gini, split_groups =\ + col, row[col], gini, groups return split_col, split_value, split_gini, split_groups def _test_split(self, col, value, data): From 36f502598ab6a4c9e35d5c87c8b53ba81fead961 Mon Sep 17 00:00:00 2001 From: Ted Callahan Date: Sat, 11 Feb 2017 12:37:00 -0800 Subject: [PATCH 10/13] Troubleshooting weird behavior. Passing the torch. --- src/decision_tree.py | 21 ++++++++++----------- 1 file changed, 10 insertions(+), 11 deletions(-) diff --git a/src/decision_tree.py b/src/decision_tree.py index ea67541..88f940c 100644 --- a/src/decision_tree.py +++ b/src/decision_tree.py @@ -38,7 +38,7 @@ def fit(self, data): self.root.left = self._build_tree(split_groups[0]) self.root.right = self._build_tree(split_groups[1]) - def _build_tree(self, data, depth_count=0): + def _build_tree(self, data, depth_count=1): """Given a node, build the tree.""" split_col, split_value, split_gini, split_groups = self._get_split(data) new_node = TreeNode(data, split_value, split_gini, split_col) @@ -101,24 +101,23 @@ def _get_split(self, data): split_col, split_value, split_gini, split_groups =\ float('inf'), float('inf'), float('inf'), None for col in data.columns.values[:-2]: - for i in range(len(data)): - row = data.iloc[i] - groups = self._test_split(col, row[col], data) + for row in data.iterrows(): + groups = self._test_split(col, row[1][col], data) gini = self._calculate_gini(groups, classes) - if gini < split_gini: + if gini < split_gini and len(groups[0]) > 0 and len(groups[1]) > 0: split_col, split_value, split_gini, split_groups =\ - col, row[col], gini, groups + col, row[1][col], gini, groups + # print("Col: ", split_col, "s_val: ", split_value, "gini: ", split_gini, "\n groups:", split_groups) return split_col, split_value, split_gini, split_groups def _test_split(self, col, value, data): """Given a dataset, column index, and value, split the dataset.""" left, right = pd.DataFrame(columns=data.columns), pd.DataFrame(columns=data.columns) - for i in range(len(data)): - row = data.iloc[i] - if row[col] < value: - left = left.append(row) + for row in data.iterrows(): + if row[1][col] < value: + left = left.append(row[1]) else: - right = right.append(row) + right = right.append(row[1]) return left, right def predict(self, data): From 1f0c6563618d1ba2c3baa9cb9f223018d2804d20 Mon Sep 17 00:00:00 2001 From: pasaunders Date: Sat, 11 Feb 2017 14:57:07 -0800 Subject: [PATCH 11/13] more attempted debugging --- src/decision_tree.py | 11 ++++++----- src/test_decision_tree.py | 4 ++-- 2 files changed, 8 insertions(+), 7 deletions(-) diff --git a/src/decision_tree.py b/src/decision_tree.py index 88f940c..e41fe6b 100644 --- a/src/decision_tree.py +++ b/src/decision_tree.py @@ -58,7 +58,7 @@ def _build_tree(self, data, depth_count=1): def _can_split(self, gini, depth_count, data_size): """Given a gini value, determine whether or not tree can split.""" if gini == 0.0: - print("Bad gini") + print("gini zero") return False elif depth_count >= self.max_depth: print("bad depth") @@ -97,16 +97,17 @@ def _calculate_gini(self, groups, class_values): def _get_split(self, data): """Choose a split point with lowest gini index.""" + if len(data) < 10: + import pdb; pdb.set_trace() classes = data[data.columns[-1]].unique() - split_col, split_value, split_gini, split_groups =\ - float('inf'), float('inf'), float('inf'), None + split_col, split_value, split_gini, split_groups = float('inf'), float('inf'), float('inf'), None for col in data.columns.values[:-2]: for row in data.iterrows(): groups = self._test_split(col, row[1][col], data) gini = self._calculate_gini(groups, classes) + # import pdb; pdb.set_trace() if gini < split_gini and len(groups[0]) > 0 and len(groups[1]) > 0: - split_col, split_value, split_gini, split_groups =\ - col, row[1][col], gini, groups + split_col, split_value, split_gini, split_groups = col, row[1][col], gini, groups # print("Col: ", split_col, "s_val: ", split_value, "gini: ", split_gini, "\n groups:", split_groups) return split_col, split_value, split_gini, split_groups diff --git a/src/test_decision_tree.py b/src/test_decision_tree.py index 1c83f42..b0f956f 100644 --- a/src/test_decision_tree.py +++ b/src/test_decision_tree.py @@ -99,5 +99,5 @@ def test__get_split(): dtree = DecisionTree(1, 1) split = dtree._get_split(data_table) # import pdb; pdb.set_trace() - for i in range(len(split[2])): - assert split[2][i].to_dict() == dtree._test_split(0, 5, data_table)[i].to_dict() \ No newline at end of file + for i in range(len(split[3])): + assert split[3][i].to_dict() == dtree._test_split(0, 5, data_table)[i].to_dict() \ No newline at end of file From 5a2e25e31ff713e0518ffb25cc35ae50d7c0e4fa Mon Sep 17 00:00:00 2001 From: pasaunders Date: Sat, 11 Feb 2017 15:41:56 -0800 Subject: [PATCH 12/13] more debugging --- src/decision_tree.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/src/decision_tree.py b/src/decision_tree.py index e41fe6b..b6d0f27 100644 --- a/src/decision_tree.py +++ b/src/decision_tree.py @@ -48,6 +48,7 @@ def _build_tree(self, data, depth_count=1): pass if self._can_split(split_gini, depth_count, len(data)): print("splitting") + # import pdb; pdb.set_trace() new_node.left = self._build_tree(split_groups[0], depth_count + 1) new_node.right = self._build_tree(split_groups[1], depth_count + 1) else: @@ -97,15 +98,13 @@ def _calculate_gini(self, groups, class_values): def _get_split(self, data): """Choose a split point with lowest gini index.""" - if len(data) < 10: - import pdb; pdb.set_trace() classes = data[data.columns[-1]].unique() split_col, split_value, split_gini, split_groups = float('inf'), float('inf'), float('inf'), None - for col in data.columns.values[:-2]: + for col in data.columns.values[:-1]: for row in data.iterrows(): groups = self._test_split(col, row[1][col], data) gini = self._calculate_gini(groups, classes) - # import pdb; pdb.set_trace() + import pdb; pdb.set_trace() if gini < split_gini and len(groups[0]) > 0 and len(groups[1]) > 0: split_col, split_value, split_gini, split_groups = col, row[1][col], gini, groups # print("Col: ", split_col, "s_val: ", split_value, "gini: ", split_gini, "\n groups:", split_groups) From c2c4a0745b30dcbf7f877c4c37b4ef89c951ac07 Mon Sep 17 00:00:00 2001 From: pasaunders Date: Sat, 11 Feb 2017 17:46:53 -0800 Subject: [PATCH 13/13] debugging finished? --- src/decision_tree.py | 28 +++++++++++++++++++++------- 1 file changed, 21 insertions(+), 7 deletions(-) diff --git a/src/decision_tree.py b/src/decision_tree.py index b6d0f27..d36b78b 100644 --- a/src/decision_tree.py +++ b/src/decision_tree.py @@ -32,25 +32,34 @@ def __init__(self, max_depth, min_leaf_size): def fit(self, data): """Create a tree to fit the data.""" + print('building first node with data: ', data) split_col, split_value, split_gini, split_groups = self._get_split(data) + split_col new_node = TreeNode(data, split_value, split_gini, split_col) self.root = new_node + print('building root left with:', split_groups[0]) self.root.left = self._build_tree(split_groups[0]) + print('building root right with:', split_groups[1]) self.root.right = self._build_tree(split_groups[1]) def _build_tree(self, data, depth_count=1): """Given a node, build the tree.""" + print('in build_tree ', data) split_col, split_value, split_gini, split_groups = self._get_split(data) new_node = TreeNode(data, split_value, split_gini, split_col) try: new_node.label = data[data.columns[-1]].mode()[0] except: pass - if self._can_split(split_gini, depth_count, len(data)): + if depth_count >= self.max_depth: + return new_node + if len(data) <= self.min_leaf_size: + return new_node print("splitting") - # import pdb; pdb.set_trace() - new_node.left = self._build_tree(split_groups[0], depth_count + 1) - new_node.right = self._build_tree(split_groups[1], depth_count + 1) + new_node.left = self._build_tree(split_groups[0], depth_count + 1) + new_node.right = self._build_tree(split_groups[1], depth_count + 1) + if split_gini == 0.0: + return new_node else: print("terminating") return new_node @@ -94,20 +103,25 @@ def _calculate_gini(self, groups, class_values): continue proportion = len(group[group[group.columns[-1]] == class_value]) / float(size) gini += (proportion * (1.0 - proportion)) + print('gini value: ', gini) return gini def _get_split(self, data): """Choose a split point with lowest gini index.""" + print('getting split') + print('columns: ', data.columns.values[:-1]) + print('rows: ', data.iterrows()) classes = data[data.columns[-1]].unique() split_col, split_value, split_gini, split_groups = float('inf'), float('inf'), float('inf'), None for col in data.columns.values[:-1]: for row in data.iterrows(): groups = self._test_split(col, row[1][col], data) gini = self._calculate_gini(groups, classes) - import pdb; pdb.set_trace() - if gini < split_gini and len(groups[0]) > 0 and len(groups[1]) > 0: + if gini < split_gini: + print('new gini: ', gini, ' at value: ', row[1][col]) split_col, split_value, split_gini, split_groups = col, row[1][col], gini, groups - # print("Col: ", split_col, "s_val: ", split_value, "gini: ", split_gini, "\n groups:", split_groups) + print("Return from get_split. Col: ", split_col, "s_val: ", split_value, "gini: ", split_gini, "\n groups:", split_groups) + # import pdb; pdb.set_trace() return split_col, split_value, split_gini, split_groups def _test_split(self, col, value, data):