google-research

Форк
0
74 строки · 2.4 Кб
1
# coding=utf-8
2
# Copyright 2024 The Google Research Authors.
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
#     http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15

16
"""Main file to run AwA experiments."""
17
import os
18
import awa_helper
19
import ipca
20

21
DATA_DIR = '/mnt/disks/mndir/Animals_with_Attributes2/'
22
TRAIN_DIR = os.path.join(DATA_DIR, 'train/')
23
VALID_DIR = os.path.join(DATA_DIR, 'val/')
24
SIZE = (299, 299)
25
BATCH_SIZE = 64
26
n_concept = 8
27
pretrained = False
28

29
if __name__ == '__main__':
30

31
  y_train_logit, y_val_logit, y_train, \
32
      y_val, f_train, f_val, dense2, predict = awa_helper.load_data(TRAIN_DIR,
33
                                                                    SIZE,
34
                                                                    BATCH_SIZE,
35
                                                                    pretrained,
36
                                                                    noise=0.0)
37

38
  concept_arraynew, concept_arraynew_active, \
39
      concept_list, active_list = awa_helper.load_conceptarray()
40

41
  finetuned_model_pr = ipca.ipca_model(
42
      concept_arraynew_active,
43
      dense2,
44
      predict,
45
      f_train,
46
      y_train_logit,
47
      f_val,
48
      y_val_logit,
49
      n_concept,
50
      verbose=True,
51
      epochs=150,
52
      metric='accuracy')
53

54
  num_epoch = 50
55
  for _ in range(num_epoch):
56
    finetuned_model_pr.fit(
57
        f_train,
58
        y_train_logit,
59
        batch_size=100,
60
        epochs=10,
61
        verbose=1,
62
        validation_data=(f_val, y_val_logit))
63

64
  concept_matrix = finetuned_model_pr.layers[-5].get_weights()[0]
65

66
  # Plots nearest neighbors in each cluster for each concept.
67
  awa_helper.plot_nearestneighbors(concept_arraynew_active, concept_matrix,
68
                                   concept_list, active_list)
69

70
  # Calculates conceptSHAP.
71
  shap_model = ipca.ipca_model_shap(dense2, predict, n_concept, 1024,
72
                                    concept_matrix)
73

74
  print(ipca.get_shap(n_concept, f_val, y_val_logit, shap_model, 0.94, 0.019))
75

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.