google-research

Форк
0
251 строка · 8.4 Кб
1
# coding=utf-8
2
# Copyright 2024 The Google Research Authors.
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
#     http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15

16
"""Helper file to run the discover concept algorithm in the toy dataset."""
17
from __future__ import absolute_import
18
from __future__ import division
19
from __future__ import print_function
20

21
import itertools
22

23
from absl import app
24
import numpy as np
25
from numpy import inf
26
from numpy.random import seed
27
from scipy.special import comb
28
from tensorflow import keras
29
import tensorflow.compat.v1 as tf
30
from tensorflow.compat.v1 import set_random_seed
31
from tensorflow.keras.activations import sigmoid
32
import tensorflow.keras.backend as K
33
from tensorflow.keras.layers import Input
34
from tensorflow.keras.layers import Lambda
35
from tensorflow.keras.layers import Layer
36
from tensorflow.keras.models import Model
37
from tensorflow.keras.optimizers import Adam
38
from tensorflow.keras.optimizers import SGD
39

40
seed(0)
41
set_random_seed(0)
42

43
# global variables
44
init = keras.initializers.RandomUniform(minval=-0.5, maxval=0.5, seed=None)
45
batch_size = 128
46

47
step = 200
48
min_weight_arr = []
49
min_index_arr = []
50
concept_arr = {}
51

52

53
class Weight(Layer):
  """Keras layer that exposes a single trainable weight tensor.

  The layer ignores its input entirely; calling it simply returns the
  trainable kernel of shape `dim`. This lets a free-standing parameter
  matrix participate in a functional-API graph.
  """

  def __init__(self, dim, **kwargs):
    self.dim = dim
    super().__init__(**kwargs)

  def build(self, input_shape):
    # Register the trainable variable; `init` is the module-level
    # RandomUniform initializer.
    self.kernel = self.add_weight(
        name='proj', shape=self.dim, initializer=init, trainable=True)
    super().build(input_shape)

  def call(self, inputs):
    # The input tensor is intentionally unused — output is the kernel itself.
    return self.kernel

  def compute_output_shape(self, input_shape):
    return self.dim
71

72

73
def reduce_var(x, axis=None, keepdims=False):
  """Returns the variance of a tensor along the specified axis.

  Args:
    x: Input tensor.
    axis: Axis (or axes) along which to compute the variance; None reduces
      over all dimensions.
    keepdims: If True, retains the reduced dimensions with length 1.

  Returns:
    A tensor holding the variance of `x` over `axis`.
  """
  # `keepdims` replaces the deprecated `keep_dims` keyword, which was removed
  # from the TF2 API surface and only works under the compat.v1 shim.
  m = tf.reduce_mean(x, axis=axis, keepdims=True)
  devs_squared = tf.square(x - m)
  return tf.reduce_mean(devs_squared, axis=axis, keepdims=keepdims)
78

79

80
def concept_loss(cov, cov0, i, n_concept, lmbd=5.):
  """Creates a concept loss based on reconstruction loss.

  Args:
    cov: Concept-by-concept similarity tensor (pushed toward identity).
    cov0: Mean-centered covariance tensor (pushed toward zero).
    i: Phase flag; 0 trains with pure reconstruction loss, any other value
      adds the covariance penalties.
    n_concept: Number of concepts (size of the identity target for `cov`).
    lmbd: Weight of the covariance penalty terms.

  Returns:
    A Keras-compatible loss function `loss(y_true, y_pred)`.
  """

  def loss(y_true, y_pred):
    # Mean binary cross-entropy reconstruction term.
    recon_loss = tf.reduce_mean(
        tf.keras.backend.binary_crossentropy(y_true, y_pred))
    if i == 0:
      return recon_loss
    # Penalize concept correlation (cov should approach the identity) and
    # deviation of the centered covariances from zero.
    return recon_loss + lmbd * K.mean(
        cov - np.eye(n_concept)) + lmbd * K.mean(cov0)

  return loss
93

94

95
def concept_variance(cov, cov0, i, n_concept):
  """Creates a concept loss based on reconstruction variance.

  Args:
    cov: Concept-by-concept similarity tensor (pushed toward identity).
    cov0: Mean-centered covariance tensor (pushed toward zero).
    i: Phase flag; 0 yields a pure variance loss, any other value adds the
      covariance penalties (fixed weight 10).
    n_concept: Number of concepts (size of the identity target for `cov`).

  Returns:
    A Keras-compatible loss function `loss(y_true, y_pred)`; the true labels
    are ignored.
  """

  def loss(_, y_pred):
    # Mean per-feature variance of the predictions across the batch.
    variance_term = 1. * tf.reduce_mean(reduce_var(y_pred, axis=0))
    if i == 0:
      return variance_term
    return variance_term + 10. * K.mean(
        cov - np.eye(n_concept)) + 10. * K.mean(cov0)

  return loss
106

107

108
def ipca_model(concept_arraynew2,
               dense2,
               predict,
               f_train,
               y_train,
               f_val,
               y_val,
               n_concept,
               verbose=False,
               epochs=20,
               metric='binary_accuracy'):
  """Builds and trains the IPCA concept-discovery model.

  A trainable projection (`Weight` layer) maps the feature space onto
  `n_concept` concept directions; features are reconstructed through the
  projection's (ridge-regularized) pseudo-inverse and fed to the frozen
  `dense2`/`predict` head of the original model.

  Args:
    concept_arraynew2: Array of clustered feature activations used to build
      the concept/projection alignment (covariance) terms.
    dense2: Pre-trained dense layer of the original model (kept frozen).
    predict: Pre-trained prediction layer of the original model (frozen).
    f_train: Training features, shape (num_examples, feature_dim).
    y_train: Training labels.
    f_val: Validation features.
    y_val: Validation labels.
    n_concept: Number of concept directions to learn.
    verbose: Passed to `fit`; also prints the model summary when truthy.
    epochs: Number of epochs for the reconstruction-only training phase.
    metric: Keras metric name used when compiling.

  Returns:
    The trained Keras model, recompiled with the covariance-penalized loss
    (phase i=1) so the caller can continue fine-tuning.
  """
  pool1f_input = Input(shape=(f_train.shape[1],), name='pool1_input')
  # Fixed (non-trainable) tensor holding the cluster activations.
  cluster_input = K.variable(concept_arraynew2)
  # Trainable projection matrix of shape (feature_dim, n_concept).
  proj_weight = Weight((f_train.shape[1], n_concept))(pool1f_input)
  proj_weight_n = Lambda(lambda x: K.l2_normalize(x, axis=0))(proj_weight)
  # Small ridge term so the Gram matrix below is safely invertible.
  eye = K.eye(n_concept) * 1e-5
  # W (W^T W + eps*I)^{-1}: right factor of the pseudo-inverse reconstruction.
  proj_recon_t = Lambda(
      lambda x: K.dot(x, tf.linalg.inv(K.dot(K.transpose(x), x) + eye)))(
          proj_weight)
  # Reconstruction of the input through the concept subspace:
  # f W (W^T W)^{-1} W^T.
  proj_recon = Lambda(lambda x: K.dot(K.dot(x[0], x[2]), K.transpose(x[1])))(
      [pool1f_input, proj_weight, proj_recon_t])
  # proj_recon2 = Lambda(lambda x: x[0] - K.dot(K.dot(x[0],K.dot(x[1],
  # tf.linalg.inv(K.dot(K.transpose(x[1]), x[1]) + 1e-5 * K.eye(n_concept)))),
  # K.transpose(x[1])))([pool1f_input, proj_weight])

  # Mean alignment between cluster activations and each concept direction.
  cov1 = Lambda(lambda x: K.mean(K.dot(x[0], x[1]), axis=1))(
      [cluster_input, proj_weight_n])
  # Center, l2-normalize and take absolute alignments per concept.
  cov0 = Lambda(lambda x: x - K.mean(x, axis=0, keepdims=True))(cov1)
  cov0_abs = Lambda(lambda x: K.abs(K.l2_normalize(x, axis=0)))(cov0)
  cov0_abs_flat = Lambda(lambda x: K.reshape(x, (-1, n_concept)))(cov0_abs)
  # Concept-by-concept similarity matrix; the loss pushes it toward identity.
  cov = Lambda(lambda x: K.dot(K.transpose(x), x))(cov0_abs_flat)
  fc2_pr = dense2(proj_recon)
  softmax_pr = predict(fc2_pr)
  # fc2_pr2 = dense2(proj_recon2)
  # softmax_pr2 = predict(fc2_pr2)

  finetuned_model_pr = Model(inputs=pool1f_input, outputs=softmax_pr)
  # finetuned_model_pr2 = Model(inputs=pool1f_input, outputs=softmax_pr2)
  # finetuned_model_pr2.compile(loss=
  #                             concept_loss(cov,cov0_abs,0),
  #                             optimizer = sgd(lr=0.),
  #                             metrics=['binary_accuracy'])
  # Force a sigmoid output activation, then freeze the classifier head so
  # only the Weight projection trains.
  # NOTE(review): layers[-1]/[-2]/[-3] are assumed to be predict/dense2/
  # reconstruction layers — these indices depend on the exact graph above.
  finetuned_model_pr.layers[-1].activation = sigmoid
  print(finetuned_model_pr.layers[-1].activation)
  finetuned_model_pr.layers[-1].trainable = False
  # finetuned_model_pr2.layers[-1].trainable = False
  finetuned_model_pr.layers[-2].trainable = False
  finetuned_model_pr.layers[-3].trainable = False
  # finetuned_model_pr2.layers[-2].trainable = False
  # Phase 0: reconstruction-only loss (no covariance penalty).
  finetuned_model_pr.compile(
      loss=concept_loss(cov, cov0_abs, 0, n_concept),
      optimizer=Adam(lr=0.001),
      metrics=[metric])
  # finetuned_model_pr2.compile(
  #    loss=concept_variance(cov, cov0_abs, 0),
  #    optimizer=SGD(lr=0.0),
  #    metrics=['binary_accuracy'])

  if verbose:
    print(finetuned_model_pr.summary())
  # finetuned_model_pr2.summary()

  finetuned_model_pr.fit(
      f_train,
      y_train,
      batch_size=50,
      epochs=epochs,
      validation_data=(f_val, y_val),
      verbose=verbose)
  finetuned_model_pr.layers[-1].trainable = False
  finetuned_model_pr.layers[-2].trainable = False
  finetuned_model_pr.layers[-3].trainable = False
  # Phase 1: recompile with the covariance penalties enabled; any further
  # training is performed by the caller.
  finetuned_model_pr.compile(
      loss=concept_loss(cov, cov0_abs, 1, n_concept),
      optimizer=Adam(lr=0.001),
      metrics=[metric])

  return finetuned_model_pr  # , finetuned_model_pr2
187

188

189
def ipca_model_shap(dense2, predict, n_concept, input_size, concept_matrix):
  """Builds the masked-projection model used to compute ConceptSHAP.

  Args:
    dense2: Pre-trained dense layer of the original model.
    predict: Pre-trained prediction layer of the original model.
    n_concept: Number of concepts.
    input_size: Dimensionality of the input features.
    concept_matrix: Learned projection matrix, shape
      (input_size, n_concept), loaded into the Weight layer.

  Returns:
    A compiled Keras model taking [features, concept_mask] and producing
    class probabilities through the masked concept projection. The optimizer
    learning rate is 0, so the model is effectively evaluation-only.
  """
  pool1f_input = Input(shape=(input_size,), name='cluster1')
  concept_mask = Input(shape=(n_concept,), name='mask')
  proj_weight = Weight((input_size, n_concept))(pool1f_input)
  # Collapse the per-example masks into a single 1 x n_concept row.
  concept_mask_r = Lambda(lambda x: K.mean(x, axis=0, keepdims=True))(
      concept_mask)
  # Zero out the projection columns of masked-off concepts.
  proj_weight_m = Lambda(lambda x: x[0] * x[1])([proj_weight, concept_mask_r])
  # Tiny ridge term keeps the Gram matrix invertible when columns are masked.
  eye = K.eye(n_concept) * 1e-10
  proj_recon_t = Lambda(
      lambda x: K.dot(x, tf.linalg.inv(K.dot(K.transpose(x), x) + eye)))(
          proj_weight_m)
  # Reconstruct the features through the masked concept subspace.
  proj_recon = Lambda(lambda x: K.dot(K.dot(x[0], x[2]), K.transpose(x[1])))(
      [pool1f_input, proj_weight_m, proj_recon_t])
  fc2_pr = dense2(proj_recon)
  softmax_pr = predict(fc2_pr)
  finetuned_model_pr = Model(
      inputs=[pool1f_input, concept_mask], outputs=softmax_pr)
  # lr=0: compiling only enables .evaluate(); no weights should change.
  finetuned_model_pr.compile(
      loss='categorical_crossentropy',
      optimizer=SGD(lr=0.000),
      metrics=['accuracy'])
  finetuned_model_pr.summary()
  # NOTE(review): layers[-7] is assumed to be the Weight projection layer;
  # this magic index depends on the exact graph built above — confirm if the
  # architecture is ever edited.
  finetuned_model_pr.layers[-7].set_weights([concept_matrix])
  return finetuned_model_pr
214

215

216
def get_acc(binary_sample, f_val, y_val_logit, shap_model, verbose=False):
  """Evaluates `shap_model` under one concept mask and returns its accuracy.

  Args:
    binary_sample: 0/1 concept mask (length n_concept) to apply.
    f_val: Validation features, shape (num_examples, feature_dim).
    y_val_logit: Validation labels.
    shap_model: Keras model built by `ipca_model_shap`.
    verbose: Verbosity forwarded to `evaluate`.

  Returns:
    The accuracy metric (second element of `evaluate`'s result).
  """
  # Broadcast the single mask to every validation example.
  mask_batch = np.tile(np.array(binary_sample), (f_val.shape[0], 1))
  results = shap_model.evaluate(
      [f_val, mask_batch], y_val_logit, verbose=verbose)
  return results[1]
223

224

225
def shap_kernel(n, k):
  """Returns the KernelSHAP weighting kernel for a coalition of size k.

  Args:
    n: Total number of players (concepts).
    k: Coalition size, 0 <= k <= n.

  Returns:
    The Shapley kernel weight (n-1) / (C(n, k) * k * (n-k)); np.inf for the
    degenerate coalitions k == 0 or k == n, which callers zero out.
  """
  # Guard the empty/full coalitions explicitly: with plain Python ints the
  # original expression raised ZeroDivisionError (only numpy ints yielded
  # inf via a runtime warning). Returning inf preserves the downstream
  # `k[k == inf] = 0` handling in get_shap.
  if k == 0 or k == n:
    return np.inf
  return (n-1)*1.0/((n-k)*k*comb(n, k))
228

229

230
def get_shap(nc, f_val, y_val_logit, shap_model, full_acc, null_acc, n_concept):
  """Returns ConceptSHAP values via weighted least squares (KernelSHAP).

  Args:
    nc: Number of players used by the Shapley kernel.
    f_val: Validation features.
    y_val_logit: Validation labels.
    shap_model: Masked-projection model from `ipca_model_shap`.
    full_acc: Accuracy with all concepts enabled.
    null_acc: Accuracy with all concepts disabled.
    n_concept: Number of concepts (mask length).

  Returns:
    Array of per-concept SHAP values.
  """
  # Enumerate every on/off coalition of the n_concept concepts.
  coalitions = list(itertools.product([0, 1], repeat=n_concept))
  # Accuracy of each coalition, normalized between the null and full models.
  scores = [(get_acc(mask, f_val, y_val_logit, shap_model) - null_acc) /
            (full_acc - null_acc) for mask in coalitions]
  weights = np.array([shap_kernel(nc, np.sum(mask)) for mask in coalitions])
  # The empty and full coalitions carry infinite kernel weight; drop them.
  weights[weights == inf] = 0
  x = np.array(coalitions)
  y = np.array(scores)
  w = np.diag(weights)
  # Solve the weighted least-squares normal equations with a pseudo-inverse.
  xkx = np.matmul(np.matmul(x.transpose(), w), x)
  xky = np.matmul(np.matmul(x.transpose(), w), y)
  return np.matmul(np.linalg.pinv(xkx), xky)
244

245

246
def main(_):
  """No-op absl entry point; this module is used as a library."""
  return
248

249

250
if __name__ == '__main__':
  # absl.app parses command-line flags before invoking main.
  app.run(main)
252

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.