google-research
251 lines · 8.4 KB
# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Helper file to run the discover concept algorithm in the toy dataset."""
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import itertools
22
23from absl import app
24import numpy as np
25from numpy import inf
26from numpy.random import seed
27from scipy.special import comb
28from tensorflow import keras
29import tensorflow.compat.v1 as tf
30from tensorflow.compat.v1 import set_random_seed
31from tensorflow.keras.activations import sigmoid
32import tensorflow.keras.backend as K
33from tensorflow.keras.layers import Input
34from tensorflow.keras.layers import Lambda
35from tensorflow.keras.layers import Layer
36from tensorflow.keras.models import Model
37from tensorflow.keras.optimizers import Adam
38from tensorflow.keras.optimizers import SGD
39
40seed(0)
41set_random_seed(0)
42
43# global variables
44init = keras.initializers.RandomUniform(minval=-0.5, maxval=0.5, seed=None)
45batch_size = 128
46
47step = 200
48min_weight_arr = []
49min_index_arr = []
50concept_arr = {}
51
52
53class Weight(Layer):
54"""Simple Weight class."""
55
56def __init__(self, dim, **kwargs):
57self.dim = dim
58super(Weight, self).__init__(**kwargs)
59
60def build(self, input_shape):
61# creates a trainable weight variable for this layer.
62self.kernel = self.add_weight(
63name='proj', shape=self.dim, initializer=init, trainable=True)
64super(Weight, self).build(input_shape)
65
66def call(self, x):
67return self.kernel
68
69def compute_output_shape(self, input_shape):
70return self.dim
71
72
73def reduce_var(x, axis=None, keepdims=False):
74"""Returns variance of a tensor, alongside the specified axis."""
75m = tf.reduce_mean(x, axis=axis, keep_dims=True)
76devs_squared = tf.square(x - m)
77return tf.reduce_mean(devs_squared, axis=axis, keep_dims=keepdims)
78
79
80def concept_loss(cov, cov0, i, n_concept, lmbd=5.):
81"""Creates a concept loss based on reconstruction loss."""
82
83def loss(y_true, y_pred):
84if i == 0:
85return tf.reduce_mean(
86tf.keras.backend.binary_crossentropy(y_true, y_pred))
87else:
88return tf.reduce_mean(
89tf.keras.backend.binary_crossentropy(y_true, y_pred)
90) + lmbd * K.mean(cov - np.eye(n_concept)) + lmbd * K.mean(cov0)
91
92return loss
93
94
95def concept_variance(cov, cov0, i, n_concept):
96"""Creates a concept loss based on reconstruction variance."""
97
98def loss(_, y_pred):
99if i == 0:
100return 1. * tf.reduce_mean(reduce_var(y_pred, axis=0))
101else:
102return 1. * tf.reduce_mean(reduce_var(y_pred, axis=0)) + 10. * K.mean(
103cov - np.eye(n_concept)) + 10. * K.mean(cov0)
104
105return loss
106
107
108def ipca_model(concept_arraynew2,
109dense2,
110predict,
111f_train,
112y_train,
113f_val,
114y_val,
115n_concept,
116verbose=False,
117epochs=20,
118metric='binary_accuracy'):
119"""Returns main function of ipca."""
120pool1f_input = Input(shape=(f_train.shape[1],), name='pool1_input')
121cluster_input = K.variable(concept_arraynew2)
122proj_weight = Weight((f_train.shape[1], n_concept))(pool1f_input)
123proj_weight_n = Lambda(lambda x: K.l2_normalize(x, axis=0))(proj_weight)
124eye = K.eye(n_concept) * 1e-5
125proj_recon_t = Lambda(
126lambda x: K.dot(x, tf.linalg.inv(K.dot(K.transpose(x), x) + eye)))(
127proj_weight)
128proj_recon = Lambda(lambda x: K.dot(K.dot(x[0], x[2]), K.transpose(x[1])))(
129[pool1f_input, proj_weight, proj_recon_t])
130# proj_recon2 = Lambda(lambda x: x[0] - K.dot(K.dot(x[0],K.dot(x[1],
131# tf.linalg.inv(K.dot(K.transpose(x[1]), x[1]) + 1e-5 * K.eye(n_concept)))),
132# K.transpose(x[1])))([pool1f_input, proj_weight])
133
134cov1 = Lambda(lambda x: K.mean(K.dot(x[0], x[1]), axis=1))(
135[cluster_input, proj_weight_n])
136cov0 = Lambda(lambda x: x - K.mean(x, axis=0, keepdims=True))(cov1)
137cov0_abs = Lambda(lambda x: K.abs(K.l2_normalize(x, axis=0)))(cov0)
138cov0_abs_flat = Lambda(lambda x: K.reshape(x, (-1, n_concept)))(cov0_abs)
139cov = Lambda(lambda x: K.dot(K.transpose(x), x))(cov0_abs_flat)
140fc2_pr = dense2(proj_recon)
141softmax_pr = predict(fc2_pr)
142# fc2_pr2 = dense2(proj_recon2)
143# softmax_pr2 = predict(fc2_pr2)
144
145finetuned_model_pr = Model(inputs=pool1f_input, outputs=softmax_pr)
146# finetuned_model_pr2 = Model(inputs=pool1f_input, outputs=softmax_pr2)
147# finetuned_model_pr2.compile(loss=
148# concept_loss(cov,cov0_abs,0),
149# optimizer = sgd(lr=0.),
150# metrics=['binary_accuracy'])
151finetuned_model_pr.layers[-1].activation = sigmoid
152print(finetuned_model_pr.layers[-1].activation)
153finetuned_model_pr.layers[-1].trainable = False
154# finetuned_model_pr2.layers[-1].trainable = False
155finetuned_model_pr.layers[-2].trainable = False
156finetuned_model_pr.layers[-3].trainable = False
157# finetuned_model_pr2.layers[-2].trainable = False
158finetuned_model_pr.compile(
159loss=concept_loss(cov, cov0_abs, 0, n_concept),
160optimizer=Adam(lr=0.001),
161metrics=[metric])
162# finetuned_model_pr2.compile(
163# loss=concept_variance(cov, cov0_abs, 0),
164# optimizer=SGD(lr=0.0),
165# metrics=['binary_accuracy'])
166
167if verbose:
168print(finetuned_model_pr.summary())
169# finetuned_model_pr2.summary()
170
171finetuned_model_pr.fit(
172f_train,
173y_train,
174batch_size=50,
175epochs=epochs,
176validation_data=(f_val, y_val),
177verbose=verbose)
178finetuned_model_pr.layers[-1].trainable = False
179finetuned_model_pr.layers[-2].trainable = False
180finetuned_model_pr.layers[-3].trainable = False
181finetuned_model_pr.compile(
182loss=concept_loss(cov, cov0_abs, 1, n_concept),
183optimizer=Adam(lr=0.001),
184metrics=[metric])
185
186return finetuned_model_pr # , finetuned_model_pr2
187
188
189def ipca_model_shap(dense2, predict, n_concept, input_size, concept_matrix):
190"""returns model that calculates of SHAP."""
191pool1f_input = Input(shape=(input_size,), name='cluster1')
192concept_mask = Input(shape=(n_concept,), name='mask')
193proj_weight = Weight((input_size, n_concept))(pool1f_input)
194concept_mask_r = Lambda(lambda x: K.mean(x, axis=0, keepdims=True))(
195concept_mask)
196proj_weight_m = Lambda(lambda x: x[0] * x[1])([proj_weight, concept_mask_r])
197eye = K.eye(n_concept) * 1e-10
198proj_recon_t = Lambda(
199lambda x: K.dot(x, tf.linalg.inv(K.dot(K.transpose(x), x) + eye)))(
200proj_weight_m)
201proj_recon = Lambda(lambda x: K.dot(K.dot(x[0], x[2]), K.transpose(x[1])))(
202[pool1f_input, proj_weight_m, proj_recon_t])
203fc2_pr = dense2(proj_recon)
204softmax_pr = predict(fc2_pr)
205finetuned_model_pr = Model(
206inputs=[pool1f_input, concept_mask], outputs=softmax_pr)
207finetuned_model_pr.compile(
208loss='categorical_crossentropy',
209optimizer=SGD(lr=0.000),
210metrics=['accuracy'])
211finetuned_model_pr.summary()
212finetuned_model_pr.layers[-7].set_weights([concept_matrix])
213return finetuned_model_pr
214
215
216def get_acc(binary_sample, f_val, y_val_logit, shap_model, verbose=False):
217"""Returns accuracy."""
218acc = shap_model.evaluate(
219[f_val, np.tile(np.array(binary_sample), (f_val.shape[0], 1))],
220y_val_logit,
221verbose=verbose)[1]
222return acc
223
224
225def shap_kernel(n, k):
226"""Returns kernel of shapley in KernelSHAP."""
227return (n-1)*1.0/((n-k)*k*comb(n, k))
228
229
230def get_shap(nc, f_val, y_val_logit, shap_model, full_acc, null_acc, n_concept):
231"""Returns ConceptSHAP."""
232inputs = list(itertools.product([0, 1], repeat=n_concept))
233outputs = [(get_acc(k, f_val, y_val_logit, shap_model)-null_acc)/
234(full_acc-null_acc) for k in inputs]
235kernel = [shap_kernel(nc, np.sum(ii)) for ii in inputs]
236x = np.array(inputs)
237y = np.array(outputs)
238k = np.array(kernel)
239k[k == inf] = 0
240xkx = np.matmul(np.matmul(x.transpose(), np.diag(k)), x)
241xky = np.matmul(np.matmul(x.transpose(), np.diag(k)), y)
242expl = np.matmul(np.linalg.pinv(xkx), xky)
243return expl
244
245
246def main(_):
247return
248
249
250if __name__ == '__main__':
251app.run(main)
252