FEDOT
161 lines · 6.3 KB
1import os
2from copy import deepcopy
3from inspect import signature
4from typing import Optional
5
6from matplotlib import pyplot as plt
7from sklearn import tree
8from sklearn.tree._tree import TREE_LEAF
9
10from fedot.core.composer.metrics import Metric
11from fedot.core.composer.metrics import R2, F1
12from fedot.core.data.data import InputData
13from fedot.core.pipelines.node import PipelineNode
14from fedot.core.pipelines.pipeline import Pipeline
15from fedot.core.repository.tasks import TaskTypesEnum
16from fedot.explainability.explainer_template import Explainer
17
18
class SurrogateExplainer(Explainer):
    """
    Explainer that fits an interpretable surrogate model to imitate a pipeline's predictions.

    :param model: `Pipeline` object to be explained
    :param surrogate: surrogate name. Supported surrogates: `[dt, dtreg]`

    .. note::
        `score` stores the score of surrogate's prediction on model (equals None if the 'explain' method hasn't been
        called yet)
    """

    # Default hyperparameters applied to each supported surrogate operation.
    surrogates_default_params = {
        'dt': {'max_depth': 5},
        'dtreg': {'max_depth': 5},
    }

    def __init__(self, model: 'Pipeline', surrogate: str):
        super().__init__(model)

        # Reproduction quality of the surrogate; None until `explain` succeeds.
        self.score: Optional[float] = None

        # Bug fix: the two error messages were swapped — the type check now
        # reports the offending type, the membership check the unknown name.
        if not isinstance(surrogate, str):
            raise ValueError(f'{type(surrogate)} is not supported as a surrogate model')
        if surrogate not in self.surrogates_default_params:
            raise ValueError(f'{surrogate} is not supported as a surrogate model')

        self.surrogate_str = surrogate
        self.surrogate = get_simple_pipeline(self.surrogate_str, self.surrogates_default_params[surrogate],
                                             model.use_input_preprocessing)

    def explain(self, data: InputData, visualization: bool = False, **kwargs):
        """Fit the surrogate on `data` and optionally visualize the result.

        :param data: dataset on which the explanation is computed.
        :param visualization: if True, forward remaining kwargs to `visualize`.
        """
        try:
            self.score = fit_naive_surrogate_model(self.model, self.surrogate, data)

        except Exception as ex:
            # Best-effort: a failing surrogate fit must not crash the caller.
            print(f'Failed to fit the surrogate: {ex}')
            return

        # Pruning redundant branches and leaves
        if self.surrogate_str in ('dt', 'dtreg'):
            prune_duplicate_leaves(self.surrogate.root_node.fitted_operation)

        if visualization:
            self.visualize(**kwargs)

    def visualize(self, dpi: int = 100, figsize=(48, 12), save_path: Optional[str] = None, **kwargs):
        """Print and plot results of the last explanation. Suitable keyword parameters
        are passed to the corresponding plot function.

        :param dpi: the figure DPI, defaults to 100.
        :param figsize: the figure size in format `(width, height)`, defaults to `(48, 12)`.
        :param save_path: path to save the plot.
        """
        plt.figure(dpi=dpi, figsize=figsize)
        if self.surrogate_str in ('dt', 'dtreg'):

            if self.score is not None:
                print(f'Surrogate\'s model reproduction quality: {self.score}')
            # Plot default parameters
            plot_params = {
                'proportion': True,
                'filled': True,
                'rounded': True,
                'fontsize': 12,
            }
            # Keep only the user kwargs that `tree.plot_tree` actually accepts.
            kwargs_params = \
                {par: kwargs[par] for par in kwargs if par in signature(tree.plot_tree).parameters}

            plot_params.update(kwargs_params)

            tree.plot_tree(self.surrogate.root_node.fitted_operation, **plot_params)

        if save_path is not None:
            plt.savefig(save_path)
            print(f'Saved the plot to "{os.path.abspath(save_path)}"')
95
96
def get_simple_pipeline(model: str, custom_params: dict = None,
                        use_input_preprocessing: bool = True) -> 'Pipeline':
    """Wrap a single operation into a one-node pipeline.

    :param model: operation name for the sole pipeline node.
    :param custom_params: optional hyperparameters assigned to that node.
    :param use_input_preprocessing: whether the pipeline preprocesses its input.
    :return: `Pipeline` consisting of exactly one node.
    """
    single_node = PipelineNode(model)
    if custom_params:
        single_node.parameters = custom_params
    pipeline = Pipeline(single_node, use_input_preprocessing=use_input_preprocessing)
    return pipeline
103
104
def fit_naive_surrogate_model(
        black_box_model: 'Pipeline', surrogate_model: 'Pipeline', data: 'InputData',
        metric: 'Metric' = None) -> Optional[float]:
    """Fit `surrogate_model` to imitate `black_box_model` and score the imitation.

    :param black_box_model: fitted pipeline whose behaviour is being explained.
    :param surrogate_model: pipeline to be trained on the black-box predictions.
    :param data: input data; NOTE: its `target` is overwritten in place with the
        black-box predictions — a side effect visible to the caller.
    :param metric: metric used to compare surrogate and black-box outputs;
        defaults to F1 for classification and R2 for regression.
    :return: absolute metric value rounded to 2 decimal places.
    """
    output_mode = 'default'

    if data.task.task_type == TaskTypesEnum.classification:
        # Compare class labels rather than raw scores/probabilities.
        output_mode = 'labels'
        if metric is None:
            metric = F1
    elif data.task.task_type == TaskTypesEnum.regression and metric is None:
        metric = R2
    # NOTE(review): for any other task type with metric=None, `metric.metric`
    # below raises AttributeError — presumably relied on to be caught by the
    # caller (`SurrogateExplainer.explain`); confirm.

    prediction = black_box_model.predict(data, output_mode=output_mode)
    # Re-target the dataset to the black-box outputs so the surrogate learns
    # to reproduce the model rather than the ground truth.
    data.target = prediction.predict
    surrogate_model.fit(data)

    # Score the surrogate's predictions against the black-box predictions.
    data_c = deepcopy(data)
    data_c.target = surrogate_model.predict(data, output_mode=output_mode).predict
    score = round(abs(metric.metric(data_c, prediction)), 2)

    return score
126
127
def is_leaf(inner_tree, index):
    """Return True when node `index` has no children, i.e. both child slots are TREE_LEAF."""
    left_child = inner_tree.children_left[index]
    right_child = inner_tree.children_right[index]
    return left_child == TREE_LEAF and right_child == TREE_LEAF
132
133
def prune_index(inner_tree, decisions, index=0):
    """Recursively collapse subtrees whose leaves agree with the node's decision.

    Works bottom-up: children are pruned first, so nodes that only become
    leaves during the pass can themselves be merged afterwards. Do not call
    directly — use `prune_duplicate_leaves` instead.
    """
    # Child indices of this node never change during the recursion below
    # (pruning only rewrites the children of *descendant* nodes).
    left = inner_tree.children_left[index]
    right = inner_tree.children_right[index]

    # Descend before pruning; a child may only become prunable after its own
    # subtree has been collapsed.
    if not is_leaf(inner_tree, left):
        prune_index(inner_tree, decisions, left)
    if not is_leaf(inner_tree, right):
        prune_index(inner_tree, decisions, right)

    # Collapse this node when both children are (now) leaves that make the
    # same decision as the node itself: "unlink" them.
    if (is_leaf(inner_tree, left)
            and is_leaf(inner_tree, right)
            and decisions[index] == decisions[left]
            and decisions[index] == decisions[right]):
        inner_tree.children_left[index] = TREE_LEAF
        inner_tree.children_right[index] = TREE_LEAF
151
152
def prune_duplicate_leaves(mdl):
    """
    Function for pruning redundant leaves of a tree by Thomas (https://stackoverflow.com/users/4629950/thomas).
    Source: https://stackoverflow.com/questions/51397109/prune-unnecessary-leaves-in-sklearn-decisiontreeclassifier
    :param mdl: `DecisionTree` or `DecisionTreeRegressor` instance by sklearn.
    """
    # argmax over the value axis yields the decision made at each node.
    node_decisions = mdl.tree_.value.argmax(axis=2).flatten().tolist()
    prune_index(mdl.tree_, node_decisions)
162