# google-research (119 lines, 3.7 KB)
1# coding=utf-8
2# Copyright 2024 The Google Research Authors.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8# http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15
"""Runs the concept-discovery algorithm on the toy dataset."""
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20import pickle
21from absl import app
22import ipca
23import numpy as np
24import toy_helper
25
26
def _train_concept_model(concept_array, dense2, predict, f_train, y_train,
                         f_val, y_val, n_concept, num_epoch=5):
    """Builds a model with `ipca.ipca_model` and fits it repeatedly.

    Args:
      concept_array: Concept array forwarded to `ipca.ipca_model`.
      dense2: First model component returned by `toy_helper.load_model`.
      predict: Second model component returned by `toy_helper.load_model`.
      f_train: Training split of the dense features.
      y_train: Training split of the labels.
      f_val: Validation split of the dense features.
      y_val: Validation split of the labels.
      n_concept: Number of concepts forwarded to `ipca.ipca_model`.
      num_epoch: How many times `fit` is invoked (each call runs 10 epochs).

    Returns:
      The model object after the final `fit` call.
    """
    model = ipca.ipca_model(concept_array, dense2, predict,
                            f_train, y_train, f_val, y_val,
                            n_concept)
    for _ in range(num_epoch):
        model.fit(
            f_train,
            y_train,
            batch_size=50,
            epochs=10,
            verbose=True,
            validation_data=(f_val, y_val))
    return model


def _save_pickle(obj, path):
    """Pickles `obj` to `path` using the highest available protocol."""
    with open(path, 'wb') as handle:
        pickle.dump(obj, handle, protocol=pickle.HIGHEST_PROTOCOL)


def main(_):
    """Discovers concepts on the toy dataset for two concept arrays.

    The first pass uses `concept_arraynew2.npy` (true clusters) and writes
    `concept_matrix_sup.pickle`; the second uses `concept_arraynew.npy`
    (self-discovered clusters) and writes `concept_matrix_unsup.pickle`.
    Nearest neighbors are plotted after each pass.

    Args:
      _: Positional argument supplied by `app.run`; it is only forwarded as a
        placeholder to `toy_helper.load_model` in the pretrained path.
    """
    n_concept = 5
    n_cluster = 5
    n = 60000
    # First 80% of the samples are used for training, the rest for validation.
    n0 = int(n * 0.8)

    pretrain = True
    # Loads data.
    x, y, concept = toy_helper.load_xyconcept(n, pretrain)
    if not pretrain:
        # Raw inputs are only needed when the model is trained from scratch.
        x_train = x[:n0, :]
        x_val = x[n0:, :]
    # Labels and dense features are needed in both modes.
    y_train = y[:n0, :]
    y_val = y[n0:, :]
    all_feature_dense = np.load('all_feature_dense.npy')
    f_train = all_feature_dense[:n0, :]
    f_val = all_feature_dense[n0:, :]

    # Loads model.
    if not pretrain:
        dense2, predict, _ = toy_helper.load_model(
            x_train, y_train, x_val, y_val, pretrain=pretrain)
    else:
        # `_` is passed as a stand-in for the four data arguments; presumably
        # load_model ignores them when pretrain is set -- TODO confirm.
        dense2, predict, _ = toy_helper.load_model(
            _, _, _, _, pretrain=pretrain)

    # Loads concepts.
    concept_arraynew = np.load('concept_arraynew.npy')
    concept_arraynew2 = np.load('concept_arraynew2.npy')

    # Discovered concepts with true clusters.
    finetuned_model_pr = _train_concept_model(
        concept_arraynew2, dense2, predict,
        f_train, y_train, f_val, y_val, n_concept)
    # Evaluates groupacc and gets concept_matrix.
    concept_matrix, _ = ipca.get_groupacc(
        finetuned_model_pr,
        concept_arraynew2,
        f_train,
        f_val,
        concept,
        n_concept,
        n_cluster,
        n0,
        verbose=False)
    # Saves concept matrix.
    _save_pickle(concept_matrix, 'concept_matrix_sup.pickle')
    # Plots nearest neighbors, using only the first 1000 entries.
    feature_sp1 = np.load('feature_sp1.npy')
    segment_sp1 = np.load('segment_sp1.npy')
    feature_sp1_1000 = feature_sp1[:1000]
    segment_sp1_1000 = segment_sp1[:1000]
    ipca.plot_nearestneighbor(concept_matrix, feature_sp1_1000,
                              segment_sp1_1000)

    # Discovered concepts with self-discovered clusters.
    finetuned_model_pr = _train_concept_model(
        concept_arraynew, dense2, predict,
        f_train, y_train, f_val, y_val, n_concept)
    # NOTE(review): this pass calls toy_helper.get_groupacc and
    # toy_helper.plot_nearestneighbor while the first pass uses the ipca
    # versions -- confirm which module is intended.
    concept_matrix, _ = toy_helper.get_groupacc(
        finetuned_model_pr,
        concept_arraynew,
        f_train,
        f_val,
        concept,
        n_concept,
        n_cluster,
        n0,
        verbose=False)
    # Saves concept matrix.
    _save_pickle(concept_matrix, 'concept_matrix_unsup.pickle')
    # Plots nearest neighbors.
    toy_helper.plot_nearestneighbor(concept_matrix,
                                    feature_sp1_1000, segment_sp1_1000)
116
117
if __name__ == '__main__':
    # Script entry point: hand `main` to absl's application runner.
    app.run(main)
120