FEDOT

Форк
0
/
surrogate_explainer.py 
161 строка · 6.3 Кб
1
import os
2
from copy import deepcopy
3
from inspect import signature
4
from typing import Optional
5

6
from matplotlib import pyplot as plt
7
from sklearn import tree
8
from sklearn.tree._tree import TREE_LEAF
9

10
from fedot.core.composer.metrics import Metric
11
from fedot.core.composer.metrics import R2, F1
12
from fedot.core.data.data import InputData
13
from fedot.core.pipelines.node import PipelineNode
14
from fedot.core.pipelines.pipeline import Pipeline
15
from fedot.core.repository.tasks import TaskTypesEnum
16
from fedot.explainability.explainer_template import Explainer
17

18

19
class SurrogateExplainer(Explainer):
    """
    Explainer that approximates a `Pipeline` with an interpretable surrogate model
    fitted to reproduce the pipeline's predictions.

    :param model: `Pipeline` object to be explained
    :param surrogate: surrogate name. Supported surrogates: `[dt, dtreg]`

    .. note::
        `score` stores the score of surrogate's prediction on model (equals None if the 'explain' method hasn't been
        called yet or the last call failed to fit the surrogate)
    """

    surrogates_default_params = {
        'dt': {'max_depth': 5},
        'dtreg': {'max_depth': 5},
    }

    def __init__(self, model: 'Pipeline', surrogate: str):
        """
        :param model: `Pipeline` object to be explained
        :param surrogate: surrogate name, one of the keys of `surrogates_default_params`

        :raises ValueError: if `surrogate` is not a string or is not a supported surrogate name
        """
        super().__init__(model)

        self.score: Optional[float] = None

        # Name the offending type for non-strings and the offending name for unknown strings,
        # so the message tells the user what was actually wrong.
        if not isinstance(surrogate, str):
            raise ValueError(f'{type(surrogate)} is not supported as a surrogate model')
        if surrogate not in self.surrogates_default_params:
            raise ValueError(f'{surrogate} is not supported as a surrogate model')

        self.surrogate_str = surrogate
        # Pass a copy so the class-level defaults can never be altered through this instance's node.
        self.surrogate = get_simple_pipeline(self.surrogate_str, dict(self.surrogates_default_params[surrogate]),
                                             model.use_input_preprocessing)

    def explain(self, data: InputData, visualization: bool = False, **kwargs):
        """Fit the surrogate so that it reproduces the explained model's predictions on `data`.

        :param data: data passed to both the explained model (for predictions) and the surrogate (for fitting)
        :param visualization: whether to call `visualize` after a successful fit
        :param kwargs: forwarded to `visualize`
        """
        # Forget the previous explanation's score so a failed fit is not reported with a stale value.
        self.score = None
        try:
            self.score = fit_naive_surrogate_model(self.model, self.surrogate, data)
        except Exception as ex:
            # Best-effort: report the failure and leave `score` as None instead of raising.
            print(f'Failed to fit the surrogate: {ex}')
            return

        # Pruning redundant branches and leaves
        if self.surrogate_str in ('dt', 'dtreg'):
            prune_duplicate_leaves(self.surrogate.root_node.fitted_operation)

        if visualization:
            self.visualize(**kwargs)

    def visualize(self, dpi: int = 100, figsize=(48, 12), save_path: Optional[str] = None, **kwargs):
        """Print and plot results of the last explanation. Suitable keyword parameters
        are passed to the corresponding plot function.

        :param dpi: the figure DPI, defaults to 100.
        :param figsize: the figure size in format `(width, height)`, defaults to `(48, 12)`.
        :param save_path: path to save the plot.
        :param kwargs: parameters of `sklearn.tree.plot_tree` overriding the defaults below;
            names `plot_tree` does not accept are ignored.
        """
        plt.figure(dpi=dpi, figsize=figsize)
        if self.surrogate_str in ('dt', 'dtreg'):
            if self.score is not None:
                print(f'Surrogate\'s model reproduction quality: {self.score}')
            # Plot default parameters
            plot_params = {
                'proportion': True,
                'filled': True,
                'rounded': True,
                'fontsize': 12,
            }
            # Plot parameters defined by user; the signature is inspected once, not per kwarg
            accepted_params = signature(tree.plot_tree).parameters
            plot_params.update({name: value for name, value in kwargs.items() if name in accepted_params})

            tree.plot_tree(self.surrogate.root_node.fitted_operation, **plot_params)

        if save_path is not None:
            plt.savefig(save_path)
            print(f'Saved the plot to "{os.path.abspath(save_path)}"')
95

96

97
def get_simple_pipeline(model: str, custom_params: dict = None,
                        use_input_preprocessing: bool = True) -> 'Pipeline':
    """Build a pipeline made of a single node.

    :param model: operation name of the only node
    :param custom_params: node parameters; when empty or None the node keeps its own defaults
    :param use_input_preprocessing: forwarded to the `Pipeline` constructor
    :return: the one-node `Pipeline`
    """
    single_node = PipelineNode(model)
    # A falsy value (None or {}) deliberately leaves the node's parameters untouched.
    if custom_params:
        single_node.parameters = custom_params
    pipeline = Pipeline(single_node, use_input_preprocessing=use_input_preprocessing)
    return pipeline
103

104

105
def fit_naive_surrogate_model(
        black_box_model: 'Pipeline', surrogate_model: 'Pipeline', data: 'InputData',
        metric: 'Metric' = None) -> Optional[float]:
    """Fit `surrogate_model` to the predictions of `black_box_model` and score how well it reproduces them.

    `data` itself is not modified: the surrogate is trained on a copy whose target is replaced by
    the black-box predictions.

    :param black_box_model: fitted pipeline being explained
    :param surrogate_model: pipeline to fit as the surrogate (fitted in place)
    :param data: input data; its features are used, its target is ignored
    :param metric: quality metric class exposing `metric(reference, predicted)`; defaults to `F1`
        for classification and `R2` for regression
    :return: absolute metric value of the surrogate against the black-box predictions, rounded to 2 digits

    :raises ValueError: if no metric is given and the task type has no default metric
    """
    output_mode = 'default'
    task_type = data.task.task_type

    if task_type == TaskTypesEnum.classification:
        output_mode = 'labels'
        if metric is None:
            metric = F1
    elif task_type == TaskTypesEnum.regression and metric is None:
        metric = R2

    # Fail early with a clear message instead of an AttributeError on `None.metric` after fitting.
    if metric is None:
        raise ValueError(f'No default surrogate quality metric for task type {task_type}; '
                         f'pass `metric` explicitly')

    black_box_prediction = black_box_model.predict(data, output_mode=output_mode)

    # Work on a copy so the caller's `data.target` is not overwritten with model predictions.
    surrogate_data = deepcopy(data)
    surrogate_data.target = black_box_prediction.predict
    surrogate_model.fit(surrogate_data)

    surrogate_prediction = surrogate_model.predict(surrogate_data, output_mode=output_mode)
    # Reference (y_true) is the black-box output the surrogate must reproduce;
    # predicted (y_pred) is the surrogate output. Order matters for non-symmetric metrics such as R2.
    # NOTE(review): `abs` maps a negative R2 to a positive score; kept for backward compatibility.
    score = round(abs(metric.metric(surrogate_data, surrogate_prediction)), 2)

    return score
126

127

128
def is_leaf(inner_tree, index):
    """Tell whether node `index` of a low-level sklearn ``Tree`` has neither a left nor a right child."""
    has_no_left_child = inner_tree.children_left[index] == TREE_LEAF
    return has_no_left_child and inner_tree.children_right[index] == TREE_LEAF
132

133

134
def prune_index(inner_tree, decisions, index=0):
    """Recursively turn redundant splits below node `index` into leaves, in place.

    Do not use this directly - use `prune_duplicate_leaves` instead.

    :param inner_tree: low-level sklearn ``Tree`` (``estimator.tree_``); entries of its
        ``children_left``/``children_right`` arrays are overwritten with ``TREE_LEAF``
    :param decisions: per-node decision indexable by node id; nodes whose entries compare
        equal are considered to make the same decision
    :param index: id of the node whose subtree is pruned, the root by default
    """
    # A leaf has nothing to prune. Without this guard its TREE_LEAF (-1) child ids would be
    # used as indices below and silently address the last node of the tree.
    if is_leaf(inner_tree, index):
        return

    left = inner_tree.children_left[index]
    right = inner_tree.children_right[index]

    # Start pruning from the bottom - if we start from the top, we might miss
    # nodes that become leaves during pruning.
    if not is_leaf(inner_tree, left):
        prune_index(inner_tree, decisions, left)
    if not is_leaf(inner_tree, right):
        prune_index(inner_tree, decisions, right)

    # Prune children if both children are leaves now and make the same decision as this node
    if (is_leaf(inner_tree, left) and is_leaf(inner_tree, right)
            and decisions[left] == decisions[index] == decisions[right]):
        # turn node into a leaf by "unlinking" its children
        inner_tree.children_left[index] = TREE_LEAF
        inner_tree.children_right[index] = TREE_LEAF
151

152

153
def prune_duplicate_leaves(mdl):
    """
    Function for pruning redundant leaves of a tree by Thomas (https://stackoverflow.com/users/4629950/thomas).
    Source: https://stackoverflow.com/questions/51397109/prune-unnecessary-leaves-in-sklearn-decisiontreeclassifier

    Extended to regression trees: a regressor's nodes are compared by their predicted values
    rather than by an argmax, which would be 0 for every regressor node.

    :param mdl: `DecisionTree` or `DecisionTreeRegressor` instance by sklearn.
    """
    # sklearn stores `tree_.value` as (n_nodes, n_outputs, max_n_classes); the last axis has
    # length 1 for regressors.
    node_values = mdl.tree_.value
    if node_values.shape[2] > 1:
        # Classifier: a node's decision is its majority class for every output.
        per_node = node_values.argmax(axis=2)
    else:
        # Regressor: argmax over a length-1 axis is always 0, so every node would look identical and
        # the whole tree would collapse into its root; compare the predicted values instead.
        per_node = node_values[:, :, 0]
    # One hashable decision per node id (a tuple covers multi-output trees without misaligning node ids).
    decisions = [tuple(row) for row in per_node.tolist()]
    prune_index(mdl.tree_, decisions)
162

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.