# Source: google-research repository (viewer header: 74 lines, 2.4 KB).
# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Main file to run AwA (Animals with Attributes 2) experiments."""
import os

import awa_helper
import ipca

# Dataset location: the Animals with Attributes 2 root and its split folders.
DATA_DIR = '/mnt/disks/mndir/Animals_with_Attributes2/'
TRAIN_DIR, VALID_DIR = (os.path.join(DATA_DIR, split)
                        for split in ('train/', 'val/'))

# Input pipeline settings handed to awa_helper.load_data.
SIZE = (299, 299)  # presumably the input image resolution — confirm
BATCH_SIZE = 64

# Experiment settings.
n_concept = 8  # number of concepts handed to the ipca models
pretrained = False  # forwarded to awa_helper.load_data

# Fine-tuning schedule for the concept model: FINETUNE_ROUNDS successive
# fit() calls, each running EPOCHS_PER_ROUND epochs (50 x 10 = 500 epochs in
# total). The round count was previously named `num_epoch`, which understated
# how many epochs are actually run.
FINETUNE_ROUNDS = 50
EPOCHS_PER_ROUND = 10
FINETUNE_BATCH_SIZE = 100


def main():
  """Runs the AwA experiment end to end.

  Loads features/logits and concept arrays via `awa_helper`, builds and
  fine-tunes the concept model via `ipca`, plots nearest neighbors for each
  concept, and prints the conceptSHAP scores.
  """
  # NOTE(review): VALID_DIR is defined at module level but never used; the
  # validation outputs (y_val_logit, f_val) come from load_data(TRAIN_DIR, ...)
  # — confirm that awa_helper.load_data derives the validation split itself.
  (y_train_logit, y_val_logit, _y_train, _y_val, f_train, f_val, dense2,
   predict) = awa_helper.load_data(
       TRAIN_DIR, SIZE, BATCH_SIZE, pretrained, noise=0.0)

  (_concept_arraynew, concept_arraynew_active, concept_list,
   active_list) = awa_helper.load_conceptarray()

  finetuned_model_pr = ipca.ipca_model(
      concept_arraynew_active,
      dense2,
      predict,
      f_train,
      y_train_logit,
      f_val,
      y_val_logit,
      n_concept,
      verbose=True,
      epochs=150,
      metric='accuracy')

  # Further training on top of ipca_model's own `epochs=150` argument; each
  # round validates on (f_val, y_val_logit).
  for _ in range(FINETUNE_ROUNDS):
    finetuned_model_pr.fit(
        f_train,
        y_train_logit,
        batch_size=FINETUNE_BATCH_SIZE,
        epochs=EPOCHS_PER_ROUND,
        verbose=1,
        validation_data=(f_val, y_val_logit))

  # NOTE(review): assumes the concept layer is 5th from the end of the model
  # built by ipca.ipca_model and that its first weight array is the concept
  # matrix — TODO confirm against ipca.ipca_model.
  concept_matrix = finetuned_model_pr.layers[-5].get_weights()[0]

  # Plots nearest neighbors in each cluster for each concept.
  awa_helper.plot_nearestneighbors(concept_arraynew_active, concept_matrix,
                                   concept_list, active_list)

  # Calculates conceptSHAP.
  # NOTE(review): the meaning of 1024 here and of 0.94 / 0.019 below is defined
  # by ipca.ipca_model_shap / ipca.get_shap (presumably a layer width and the
  # full-model / baseline accuracies) — confirm before changing them.
  shap_model = ipca.ipca_model_shap(dense2, predict, n_concept, 1024,
                                    concept_matrix)

  print(ipca.get_shap(n_concept, f_val, y_val_logit, shap_model, 0.94, 0.019))


if __name__ == '__main__':
  main()