google-research
213 строк · 5.3 Кб
# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Runs the concept discovery algorithm on the toy dataset."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from absl import app
import ipca
import numpy as np
import toy_helper


def main(_):
    """Runs concept discovery and the ACE/PCA comparisons on the toy dataset.

    Loads the toy data through ``toy_helper.load_xyconcept``, the precomputed
    features from ``all_feature_dense.npy`` and two concept arrays from
    ``concept_arraynew.npy`` / ``concept_arraynew2.npy``. For every concept
    count from 1 to 9 it then:

    1. builds an ``ipca.ipca_model`` on each concept array, fits it, and
       calls ``ipca.get_groupacc`` on the result;
    2. in a second sweep, builds models with ``epochs=0`` whose
       ``layers[-5]`` weights are overwritten with matrices returned by
       ``toy_helper.get_ace_concept`` and ``toy_helper.get_pca_concept``,
       prints ``evaluate`` on the validation split, and calls
       ``ipca.get_groupacc``.

    Args:
      _: Positional argument passed by ``app.run``. It is only forwarded as a
        placeholder to ``toy_helper.load_model`` in the pretrained path.
    """
    n_cluster = 5
    n = 60000
    n0 = int(n * 0.8)  # First 80% of the rows train, the rest validate.
    num_epoch = 5
    pretrain = True

    # Loads data.
    x, y, concept = toy_helper.load_xyconcept(n, pretrain)
    # Labels and features are needed on every path, so they are sliced
    # unconditionally; only the raw inputs are specific to training from
    # scratch.
    y_train, y_val = y[:n0, :], y[n0:, :]
    all_feature_dense = np.load('all_feature_dense.npy')
    f_train = all_feature_dense[:n0, :]
    f_val = all_feature_dense[n0:, :]

    # Loads model.
    if not pretrain:
        x_train, x_val = x[:n0, :], x[n0:, :]
        dense2, predict, _ = toy_helper.load_model(
            x_train, y_train, x_val, y_val, pretrain=pretrain)
    else:
        # The data arguments are filled with the argv placeholder here;
        # presumably load_model ignores them when pretrain is True -- verify
        # against toy_helper.
        dense2, predict, _ = toy_helper.load_model(
            _, _, _, _, pretrain=pretrain)

    # Loads concepts.
    concept_arraynew = np.load('concept_arraynew.npy')
    concept_arraynew2 = np.load('concept_arraynew2.npy')

    def report_group_accuracy(model, concept_array, n_concept):
        """Calls ipca.get_groupacc for `model`; the result is discarded."""
        ipca.get_groupacc(
            model,
            concept_array,
            f_train,
            f_val,
            concept,
            n_concept,
            n_cluster,
            n0,
            verbose=False)

    def discover_and_score(concept_array, n_concept):
        """Builds an ipca model on `concept_array`, fits it, then scores it."""
        model = ipca.ipca_model(
            concept_array,
            dense2,
            predict,
            f_train,
            y_train,
            f_val,
            y_val,
            n_concept,
            comp1=True)
        for _ in range(num_epoch):
            model.fit(
                f_train,
                y_train,
                batch_size=50,
                epochs=10,
                verbose=True,
                validation_data=(f_val, y_val))
        # NOTE(review): nesting was reconstructed; confirm that scoring runs
        # once after the fit loop rather than after every fit call.
        report_group_accuracy(model, concept_array, n_concept)

    def score_fixed_concepts(concept_array, concept_matrix, n_concept):
        """Evaluates an untrained ipca model with `concept_matrix` injected."""
        model = ipca.ipca_model(
            concept_array,
            dense2,
            predict,
            f_train,
            y_train,
            f_val,
            y_val,
            n_concept,
            verbose=True,
            epochs=0,
            metric='accuracy')
        model.layers[-5].set_weights([concept_matrix])
        print(model.evaluate(f_val, y_val))
        report_group_accuracy(model, concept_array, n_concept)

    for n_concept in range(1, 10):
        print(n_concept)
        # Discovers concept with true cluster.
        discover_and_score(concept_arraynew2, n_concept)
        # Discovers concepts with self-discovered clusters.
        discover_and_score(concept_arraynew, n_concept)

    for n_concept in range(1, 10):
        print(n_concept)

        concept_matrix_ace = toy_helper.get_ace_concept(
            concept_arraynew, dense2, predict, f_val, n_concept)
        score_fixed_concepts(concept_arraynew, concept_matrix_ace, n_concept)

        concept_matrix_ace2 = toy_helper.get_ace_concept(
            concept_arraynew2, dense2, predict, f_val, n_concept)
        score_fixed_concepts(concept_arraynew2, concept_matrix_ace2, n_concept)

        concept_matrix_pca = toy_helper.get_pca_concept(f_train, n_concept)
        score_fixed_concepts(concept_arraynew, concept_matrix_pca, n_concept)


if __name__ == '__main__':
    # Let absl parse the command line before handing control to main.
    app.run(main)