google-research

Форк
0
213 строк · 5.3 Кб
1
# coding=utf-8
2
# Copyright 2024 The Google Research Authors.
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
#     http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15

16
"""This is the code to run the discover concept algorithm in the toy dataset."""
17
from __future__ import absolute_import
18
from __future__ import division
19
from __future__ import print_function
20
from absl import app
21
import ipca
22
import numpy as np
23
import toy_helper
24

25

26
def main(_):
  """Compares concept-discovery methods on the toy dataset.

  For each number of concepts from 1 to 9 the script:
    1. Fine-tunes an ipca model on the true clusters
       (`concept_arraynew2.npy`) and on self-discovered clusters
       (`concept_arraynew.npy`), then runs `ipca.get_groupacc` on each.
    2. Builds untrained ipca models whose concept layer is overwritten with
       concepts from `toy_helper.get_ace_concept` (one per cluster array) and
       from `toy_helper.get_pca_concept` on the training features. For each,
       it prints the validation evaluation and runs `ipca.get_groupacc`.

  Reads `all_feature_dense.npy`, `concept_arraynew.npy` and
  `concept_arraynew2.npy` from the current working directory.

  Args:
    _: Command-line arguments forwarded by absl; unused.
  """
  n_cluster = 5
  n = 60000
  n0 = int(n * 0.8)  # The first 80% of samples train; the rest validate.
  pretrain = True

  # Loads data.
  x, y, concept = toy_helper.load_xyconcept(n, pretrain)
  y_train = y[:n0, :]
  y_val = y[n0:, :]
  all_feature_dense = np.load('all_feature_dense.npy')
  f_train = all_feature_dense[:n0, :]
  f_val = all_feature_dense[n0:, :]

  # Loads model. Raw inputs are only passed when training from scratch.
  # NOTE(review): assumes load_model ignores its data arguments when
  # pretrain=True, so explicit None placeholders are passed — confirm.
  if pretrain:
    train_data = (None, None, None, None)
  else:
    train_data = (x[:n0, :], y_train, x[n0:, :], y_val)
  dense2, predict, _ = toy_helper.load_model(*train_data, pretrain=pretrain)

  def report_groupacc(model, concept_array, n_concept):
    """Runs ipca.get_groupacc for `model` against `concept_array`."""
    ipca.get_groupacc(
        model,
        concept_array,
        f_train,
        f_val,
        concept,
        n_concept,
        n_cluster,
        n0,
        verbose=False)

  def fit_ipca(concept_array, n_concept, num_rounds=5, epochs_per_round=10):
    """Builds an ipca model, fine-tunes it, then reports group accuracy."""
    model = ipca.ipca_model(
        concept_array,
        dense2,
        predict,
        f_train,
        y_train,
        f_val,
        y_val,
        n_concept,
        comp1=True)
    for _ in range(num_rounds):
      model.fit(
          f_train,
          y_train,
          batch_size=50,
          epochs=epochs_per_round,
          verbose=True,
          validation_data=(f_val, y_val))
    report_groupacc(model, concept_array, n_concept)

  def eval_fixed_concepts(concept_array, concept_matrix, n_concept):
    """Evaluates an untrained ipca model using `concept_matrix` as concepts."""
    model = ipca.ipca_model(
        concept_array,
        dense2,
        predict,
        f_train,
        y_train,
        f_val,
        y_val,
        n_concept,
        verbose=True,
        epochs=0,
        metric='accuracy')
    # layers[-5] is the layer the concept matrix is written into; this index
    # depends on the architecture built by ipca.ipca_model.
    model.layers[-5].set_weights([concept_matrix])
    print(model.evaluate(f_val, y_val))
    report_groupacc(model, concept_array, n_concept)

  # Loads concepts.
  concept_arraynew = np.load('concept_arraynew.npy')
  concept_arraynew2 = np.load('concept_arraynew2.npy')

  for n_concept in range(1, 10):
    print(n_concept)
    # Discovers concepts with the true clusters.
    fit_ipca(concept_arraynew2, n_concept)
    # Discovers concepts with self-discovered clusters.
    fit_ipca(concept_arraynew, n_concept)

  for n_concept in range(1, 10):
    print(n_concept)
    # Concepts from get_ace_concept, once per cluster array.
    for concept_array in (concept_arraynew, concept_arraynew2):
      concept_matrix_ace = toy_helper.get_ace_concept(
          concept_array, dense2, predict, f_val, n_concept)
      eval_fixed_concepts(concept_array, concept_matrix_ace, n_concept)
    # Concepts from get_pca_concept on the training features, scored
    # against the self-discovered clusters.
    concept_matrix_pca = toy_helper.get_pca_concept(f_train, n_concept)
    eval_fixed_concepts(concept_arraynew, concept_matrix_pca, n_concept)
210

211

212
if __name__ == '__main__':
  # Hand control to absl, which invokes main() with the command-line argv.
  app.run(main=main)
214

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.