diff --git a/dectree/dectree.py b/dectree/dectree.py index 04355d1..6ae4e2f 100644 --- a/dectree/dectree.py +++ b/dectree/dectree.py @@ -218,7 +218,7 @@ def progress(self, ndigits: int = 3): return round(nll/self.number_leaves, ndigits) - def plot_data(self, valid_nodes: List[List[int]], + def _to_plot_data(self, valid_nodes: List[List[int]], complete_graph_layout: bool = True): """ Draw decision tree. @@ -312,9 +312,9 @@ def plot_data(self, valid_nodes: List[List[int]], return tree_plot_data - def plot_tree_data(self, valid_nodes, limit=100): + def plot_tree_data(self, valid_nodes, limit=100, complete_graph_layout: bool = True): - datas = self._to_plot_data(valid_nodes) + datas = self._to_plot_data(valid_nodes=valid_nodes, complete_graph_layout=complete_graph_layout) circle_data = [data for data in datas if data['type'] == 'circle'] if limit is not None and len(circle_data) > limit: