From 510fe936d01b8c8ef2c13b6d89d9bd44cdd8665d Mon Sep 17 00:00:00 2001 From: younesdessia Date: Mon, 19 Feb 2024 10:29:58 +0100 Subject: [PATCH] Fix: plot tree --- dectree/dectree.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/dectree/dectree.py b/dectree/dectree.py index 04355d1..6ae4e2f 100644 --- a/dectree/dectree.py +++ b/dectree/dectree.py @@ -218,7 +218,7 @@ def progress(self, ndigits: int = 3): return round(nll/self.number_leaves, ndigits) - def plot_data(self, valid_nodes: List[List[int]], + def _to_plot_data(self, valid_nodes: List[List[int]], complete_graph_layout: bool = True): """ Draw decision tree. @@ -312,9 +312,9 @@ def plot_data(self, valid_nodes: List[List[int]], return tree_plot_data - def plot_tree_data(self, valid_nodes, limit=100): + def plot_tree_data(self, valid_nodes, limit=100, complete_graph_layout: bool = True): - datas = self._to_plot_data(valid_nodes) + datas = self._to_plot_data(valid_nodes=valid_nodes, complete_graph_layout=complete_graph_layout) circle_data = [data for data in datas if data['type'] == 'circle'] if limit is not None and len(circle_data) > limit: