google-research

Форк
0
119 строк · 3.7 Кб
1
# coding=utf-8
2
# Copyright 2024 The Google Research Authors.
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
#     http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15

16
"""This is the code to run the discover concept algorithm in the toy dataset."""
17
from __future__ import absolute_import
18
from __future__ import division
19
from __future__ import print_function
20
import pickle
21
from absl import app
22
import ipca
23
import numpy as np
24
import toy_helper
25

26

27
def main(_):
  """Trains and evaluates concept models on the toy dataset.

  Two runs are made: one seeded with `concept_arraynew2.npy` (true clusters)
  and one seeded with `concept_arraynew.npy` (self-discovered clusters). Each
  run builds a model with `ipca.ipca_model`, fits it, computes a concept
  matrix, pickles that matrix and plots nearest neighbors.

  Args:
    _: Positional argument supplied by `app.run`. It is only forwarded as a
      placeholder to `toy_helper.load_model` in the pretrained branch.
  """
  num_concepts = 5
  num_clusters = 5
  num_samples = 60000
  split = int(num_samples * 0.8)  # First 80% is train, the rest validation.
  use_pretrained = True

  # Loads data. The raw inputs are split only in the non-pretrained setting.
  x, y, concept = toy_helper.load_xyconcept(num_samples, use_pretrained)
  if not use_pretrained:
    x_train, x_val = x[:split, :], x[split:, :]
  y_train, y_val = y[:split, :], y[split:, :]
  dense_features = np.load('all_feature_dense.npy')
  f_train, f_val = dense_features[:split, :], dense_features[split:, :]

  # Loads model. In the pretrained setting the argument received by main is
  # passed four times as a placeholder for the data arguments.
  if use_pretrained:
    model_inputs = (_, _, _, _)
  else:
    model_inputs = (x_train, y_train, x_val, y_val)
  dense2, predict, _ = toy_helper.load_model(
      *model_inputs, pretrain=use_pretrained)

  # Loads concepts.
  concept_arraynew = np.load('concept_arraynew.npy')
  concept_arraynew2 = np.load('concept_arraynew2.npy')

  def fit_rounds(model, rounds=5):
    """Calls `model.fit` `rounds` times, with 10 epochs per call."""
    for _ in range(rounds):
      model.fit(
          f_train,
          y_train,
          batch_size=50,
          epochs=10,
          verbose=True,
          validation_data=(f_val, y_val))

  # Discovered concepts with true clusters.
  model = ipca.ipca_model(
      concept_arraynew2, dense2, predict,
      f_train, y_train, f_val, y_val, num_concepts)
  fit_rounds(model)
  # Evaluates groupacc and gets the concept matrix.
  concept_matrix, _ = ipca.get_groupacc(
      model, concept_arraynew2, f_train, f_val, concept,
      num_concepts, num_clusters, split, verbose=False)
  with open('concept_matrix_sup.pickle', 'wb') as handle:
    pickle.dump(concept_matrix, handle, protocol=pickle.HIGHEST_PROTOCOL)
  # Plots nearest neighbors using the first 1000 entries of each array.
  feature_sp1_1000 = np.load('feature_sp1.npy')[:1000]
  segment_sp1_1000 = np.load('segment_sp1.npy')[:1000]
  ipca.plot_nearestneighbor(concept_matrix, feature_sp1_1000, segment_sp1_1000)

  # Discovered concepts with self-discovered clusters.
  # NOTE(review): this run calls toy_helper.get_groupacc and
  # toy_helper.plot_nearestneighbor, while the run above uses the ipca
  # versions -- confirm that the difference is intended.
  model = ipca.ipca_model(
      concept_arraynew, dense2, predict,
      f_train, y_train, f_val, y_val, num_concepts)
  fit_rounds(model)
  concept_matrix, _ = toy_helper.get_groupacc(
      model, concept_arraynew, f_train, f_val, concept,
      num_concepts, num_clusters, split, verbose=False)
  with open('concept_matrix_unsup.pickle', 'wb') as handle:
    pickle.dump(concept_matrix, handle, protocol=pickle.HIGHEST_PROTOCOL)
  toy_helper.plot_nearestneighbor(
      concept_matrix, feature_sp1_1000, segment_sp1_1000)


# Script entry point: hands main to absl's app runner.
if __name__ == '__main__':
  app.run(main)
120

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.