google-research

Форк
0
269 строк · 8.5 Кб
1
# coding=utf-8
2
# Copyright 2024 The Google Research Authors.
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
#     http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15

16
"""Run Classifier to score all data.
17

18
Data scorer with classifier.
19

20
This file is intended for a dataset that is split into 14 chunks.
21
"""
22

23
import os
24
import pathlib
25
import pickle
26
from typing import Any, Dict, Sequence, Union, Optional
27

28
from absl import app
29
from absl import flags
30
from absl import logging
31
import numpy as np
32
import numpy.typing as npt
33
import tensorflow as tf
34
import transformers
35

36

37

38
FLAGS = flags.FLAGS

# Accepted values of --loss.
STRUCT = 'struct'  # Structured hinge loss over the per-class score vector.
ONE_HOT = 'onehot'  # model.compute_loss against the lowest-score class index.

flags.DEFINE_integer(
    'train_data_size', default=500_000,
    help='Size of training data.')
flags.DEFINE_integer(
    'epochs', default=5,
    help='Epochs of training.')
flags.DEFINE_integer(
    'batch_size', default=64,
    help='Batch size.')
flags.DEFINE_integer(
    'eval_freq', default=1_000,
    help='Num steps between evals.')
# An enum flag rejects unsupported values when flags are parsed, instead of
# failing later in main() with an undefined label dataset / uncompiled model.
flags.DEFINE_enum(
    'loss', default=STRUCT, enum_values=[STRUCT, ONE_HOT],
    help='Type of loss to use.')
flags.DEFINE_string(
    'cns_save_dir', default='routing/domain_clf/',
    help='Path to save dir.')
flags.DEFINE_string(
    'bert_path', default='bert_base',
    help='Path to bert base.')
flags.DEFINE_string(
    'routing_path', default='routing/',
    help='Path to save logs and losses.')
67

68

69
class CNSCheckpointCallback(tf.keras.callbacks.Callback):
  """Write checkpoints to CNS whenever the evaluation loss improves.

  Keras calls `on_test_end` at the end of an evaluation/validation pass. When
  the reported loss is lower than every loss seen before, the model is saved
  locally with `save_pretrained`. The resulting files are then copied into
  `cns_save_dir`, together with a marker file `chkpt<loss>.txt` that records
  the loss which triggered the save.

  Attributes:
    best_loss: Lowest evaluation loss seen so far.
    tmp_save_dir: Local directory that `save_pretrained` writes to.
    cns_save_dir: Destination directory the checkpoint files are copied into.
  """

  def __init__(self,
               tmp_save_dir='/tmp/model_ckpt',
               cns_save_dir='routing/domain_clf/'):
    super().__init__()
    # Start at +inf so the first evaluation always checkpoints, whatever the
    # scale of the loss. A fixed sentinel (e.g. 10000) would never save if
    # every loss exceeded it.
    self.best_loss = float('inf')
    self.tmp_save_dir = tmp_save_dir
    self.cns_save_dir = cns_save_dir

  def _save_model(self, new_loss):
    """Save the model locally, then mirror its files into `cns_save_dir`.

    Args:
      new_loss: Evaluation loss that triggered this save. It is used in both
        the name and the content of the marker file.
    """
    pathlib.Path(self.tmp_save_dir).mkdir(parents=True, exist_ok=True)
    self.model.save_pretrained(self.tmp_save_dir)
    # makedirs also creates missing parents and is a no-op if the directory
    # exists. It only needs to run once, not once per copied file.
    tf.io.gfile.makedirs(self.cns_save_dir)
    for filename in tf.io.gfile.glob(os.path.join(self.tmp_save_dir, '*')):
      logging.info('Copying checkpoint file %s', filename)
      tf.io.gfile.copy(
          filename,
          os.path.join(self.cns_save_dir, os.path.basename(filename)),
          overwrite=True)
    # Marker recording which loss this checkpoint corresponds to; written
    # once per save.
    marker = os.path.join(self.cns_save_dir, 'chkpt{}.txt'.format(new_loss))
    with tf.io.gfile.GFile(marker, 'w') as f:
      f.write(str(new_loss))

  def on_test_end(self, logs=None):
    """Checkpoint the model if the evaluation loss improved.

    Args:
      logs: Results of the finished evaluation; only `logs['loss']` is read.
        If `logs` is missing or has no 'loss' entry, no checkpoint is written.
    """
    new_loss = (logs or {}).get('loss')
    if new_loss is None:
      logging.warning('No "loss" in evaluation logs; skipping checkpoint.')
      return
    if new_loss < self.best_loss:
      self.best_loss = new_loss
      self._save_model(new_loss)
99

100

101
def custom_loss_function(y_true, y_pred):
  """Structured hinge loss between per-class costs and model scores.

  For each row, let y* be the column holding the smallest entry of `y_true`.
  The per-row loss is

      max(0, max_y [(y_true[y] - y_true[y*]) + y_pred[y] - y_pred[y*]])

  That is, the score of y* must exceed the score of every other column by at
  least that column's excess cost.

  Args:
    y_true: Costs, reduced over axis 1 (lower is better). Expected shape is
      (batch, num_classes); TODO confirm against the label pipeline.
    y_pred: Scores with the same layout as `y_true`.

  Returns:
    Tensor with one non-negative loss value per row.
  """
  best_cost = tf.reduce_min(y_true, axis=1)
  best_column = tf.argmin(y_true, axis=1)
  # Score the model assigns to the lowest-cost column of each row.
  best_score = tf.gather(y_pred, best_column, axis=1, batch_dims=1)

  # Cost-augmented score of every column, relative to the best column.
  excess_cost = tf.subtract(y_true, best_cost[:, tf.newaxis])
  augmented = tf.add(excess_cost, y_pred)
  violation = tf.subtract(augmented, best_score[:, tf.newaxis])

  worst_violation = tf.reduce_max(violation, axis=1)
  return tf.maximum(0.0, worst_violation)
117

118

119
class CustomAccuracy(tf.keras.metrics.Metric):
  """Compute accuracy with explicit score inputs.

  `y_true` carries one cost per class (the target is its argmin). `y_pred`
  carries one score per class (the prediction is its argmax). A row counts
  as correct when the two indices agree.
  """

  def __init__(self, name='custom_acc', **kwargs):
    super(CustomAccuracy, self).__init__(name=name, **kwargs)
    # Inner Keras Accuracy metric that accumulates the index matches.
    self.custom_acc = tf.keras.metrics.Accuracy(name='custom_acc', dtype=None)

  def update_state(
      self,
      y_true,
      y_pred,
      sample_weight=None
  ):
    """Accumulate matches between argmin(y_true) and argmax(y_pred).

    Args:
      y_true: Per-class costs; reduced with argmin over axis 1.
      y_pred: Per-class scores; reduced with argmax over axis 1.
      sample_weight: Optional per-row weights, forwarded to the inner
        accuracy metric so weighted evaluation is honored.
    """
    lowest_index = tf.argmin(y_true, axis=1)
    highest_logit = tf.argmax(y_pred, axis=1)
    self.custom_acc.update_state(
        lowest_index, highest_logit, sample_weight=sample_weight)

  def result(self):
    """Return the accumulated accuracy."""
    return self.custom_acc.result()

  def reset_state(self):
    """Clear the accumulated counts (the reset hook current Keras calls)."""
    self.custom_acc.reset_state()

  def reset_states(self):
    """Deprecated alias of `reset_state`, kept for existing callers."""
    self.reset_state()
141

142

143
# Number of data clusters. This is also the number of classifier labels; the
# per-cluster directories (`cluster_data/k100/id_<i>`) and the score table are
# indexed over this range.
_NUM_CLUSTERS = 100
# Rows taken from the front of each cluster's data for evaluation.
_NUM_EVAL_PER_CLUSTER = 100
# Upper bound on the number of evaluation examples that are tokenized.
_MAX_EVAL_EXAMPLES = 5000
# Token length that texts are truncated to.
_MAX_SEQ_LENGTH = 128
# Shuffle buffer size, in examples, for the training set.
_SHUFFLE_BUFFER = 100_000


def _load_cluster_datasets(wmt_labels):
  """Build (text, label) datasets for every cluster and mix them.

  For each cluster i, reads the first tab-separated column of
  `<routing_path>/cluster_data/k100/id_<i>/test_large.tsv`. Each row is then
  paired with a label built from `wmt_labels[i]`. Because the pairing uses
  `Dataset.zip`, the result is truncated to the shorter of the two.

  Args:
    wmt_labels: Unpickled score table indexed as `wmt_labels[i][j]`, with i the
      data cluster and j in range(_NUM_CLUSTERS). Each entry is presumably a
      1-D sequence of scores (lower is better), one per row of cluster i;
      TODO confirm.

  Returns:
    A (train_dataset, eval_dataset) pair. Each is a uniform sample over the
    per-cluster datasets. The first `_NUM_EVAL_PER_CLUSTER` rows of every
    cluster go to eval and the remainder to train.
  """
  all_train_ds = []
  all_eval_ds = []
  for i in range(_NUM_CLUSTERS):
    data_dir = os.path.join(FLAGS.routing_path, 'cluster_data',
                            'k{}'.format(_NUM_CLUSTERS), 'id_{}'.format(i))
    text_data = tf.data.experimental.CsvDataset(
        [os.path.join(data_dir, 'test_large.tsv')],
        record_defaults=[tf.string, tf.string],
        field_delim='\t',
        use_quote_delim=False)
    # Keep only the first column.
    text_data = text_data.map(lambda text, _: text)

    # Row j holds the scores of class j for every example of cluster i.
    all_scores = np.stack(
        [np.array(wmt_labels[i][j]) for j in range(_NUM_CLUSTERS)])
    if FLAGS.loss == STRUCT:
      # One full score vector (length _NUM_CLUSTERS) per example.
      label = tf.data.Dataset.from_tensor_slices(all_scores.T)
    else:  # ONE_HOT (validated in main).
      # Index of the lowest-score class for each example.
      label = tf.data.Dataset.from_tensor_slices(
          np.argmin(all_scores, axis=0))

    cluster_ds = tf.data.Dataset.zip((text_data, label))
    all_eval_ds.append(cluster_ds.take(_NUM_EVAL_PER_CLUSTER))
    all_train_ds.append(cluster_ds.skip(_NUM_EVAL_PER_CLUSTER))

  return (tf.data.experimental.sample_from_datasets(all_train_ds),
          tf.data.experimental.sample_from_datasets(all_eval_ds))


def _encode_examples(dataset, tokenizer, max_examples):
  """Materialize up to `max_examples` (text, label) pairs and tokenize them.

  Args:
    dataset: tf.data.Dataset yielding (text, label) pairs, with `text` a
      scalar tf.string.
    tokenizer: transformers tokenizer used to encode the texts.
    max_examples: Maximum number of examples to take from `dataset`.

  Returns:
    Unbatched tf.data.Dataset of (encoding dict, label) pairs.

  Raises:
    ValueError: If `dataset` yields no examples.
  """
  texts = []
  labels = []
  for text, label in dataset.take(max_examples):
    # Decode the bytes. str(bytes) would hand the tokenizer the Python repr
    # ("b'...'" with escaped non-ASCII characters). Undecodable bytes are
    # replaced rather than aborting a long collection run.
    texts.append(text.numpy().decode('utf-8', errors='replace'))
    labels.append(label.numpy())
  if not texts:
    raise ValueError('No examples available to encode.')
  encoding = tokenizer(
      texts,
      return_tensors='tf',
      padding=True,
      truncation=True,
      max_length=_MAX_SEQ_LENGTH)
  return tf.data.Dataset.from_tensor_slices(
      (dict(encoding), np.asarray(labels)))


def main(argv):
  """Fine-tune a BERT classifier over the clusters and checkpoint it.

  Loads the score table from `<cns_save_dir>/analysis_full/data.pkl` and
  builds training/eval examples from every cluster's data. It then fine-tunes
  the pretrained model at --bert_path with the structured loss
  (--loss=struct) or `model.compute_loss` (--loss=onehot). The model is
  checkpointed to --cns_save_dir whenever the evaluation loss improves.
  """
  if len(argv) > 1:
    raise app.UsageError('Too many command-line arguments.')
  if FLAGS.loss not in (STRUCT, ONE_HOT):
    raise app.UsageError('Unsupported --loss: {!r}'.format(FLAGS.loss))

  # NOTE: pickle.load can execute arbitrary code; only read trusted files.
  with tf.io.gfile.GFile(
      os.path.join(FLAGS.cns_save_dir, 'analysis_full', 'data.pkl'),
      'rb') as f:
    wmt_labels = pickle.load(f)

  strategy = tf.distribute.MirroredStrategy()
  with strategy.scope():
    path = FLAGS.bert_path  # pretrained model (ie. bert-base)
    cache_dir = '/tmp/'
    config = transformers.BertConfig.from_pretrained(
        os.path.join(path, 'config.json'), num_labels=_NUM_CLUSTERS,
        cache_dir=cache_dir)
    tokenizer = transformers.BertTokenizer.from_pretrained(
        path, cache_dir=cache_dir)
    model = transformers.TFBertForSequenceClassification.from_pretrained(
        os.path.join(path, 'tf_model.h5'), config=config, cache_dir=cache_dir)

  with tf.device('/CPU:0'):
    sample_dataset, eval_sample_dataset = _load_cluster_datasets(wmt_labels)
    # Shuffle individual examples before batching; shuffling after batch()
    # only permutes whole, fixed batches. model.fit bounds every epoch with
    # steps_per_epoch, so an unbounded repeat() cannot run out of data.
    train_dataset = _encode_examples(
        sample_dataset, tokenizer, FLAGS.train_data_size)
    train_dataset = train_dataset.shuffle(_SHUFFLE_BUFFER).batch(
        FLAGS.batch_size).repeat()
    eval_dataset = _encode_examples(
        eval_sample_dataset, tokenizer, _MAX_EVAL_EXAMPLES)
    eval_dataset = eval_dataset.batch(FLAGS.batch_size)

  steps_per_epoch = FLAGS.train_data_size // FLAGS.batch_size
  # Clamp to >= 1 so a tiny --train_data_size cannot yield a zero decay
  # horizon.
  num_train_steps = max(1, steps_per_epoch * FLAGS.epochs)
  decay_schedule = tf.keras.optimizers.schedules.PolynomialDecay(
      initial_learning_rate=4e-5,
      decay_steps=num_train_steps,
      end_learning_rate=0)

  with strategy.scope():
    optimizer = tf.keras.optimizers.Adam(learning_rate=decay_schedule)
    if FLAGS.loss == ONE_HOT:
      model.compile(
          optimizer=optimizer, loss=model.compute_loss, metrics=['accuracy'])
    else:  # STRUCT
      model.compile(
          optimizer=optimizer, loss=custom_loss_function,
          metrics=[CustomAccuracy()])

  # Each Keras "epoch" is eval_freq steps, so evaluation (and possibly a
  # checkpoint) happens every eval_freq steps. At least one epoch always runs.
  num_meta_epochs = max(
      1, int(steps_per_epoch / FLAGS.eval_freq * FLAGS.epochs))
  logging.info('Training %d epochs of %d steps each', num_meta_epochs,
               FLAGS.eval_freq)
  model.fit(
      train_dataset,
      epochs=num_meta_epochs,
      steps_per_epoch=FLAGS.eval_freq,
      validation_freq=1,
      validation_data=eval_dataset,
      callbacks=[CNSCheckpointCallback(cns_save_dir=FLAGS.cns_save_dir)],
      verbose=2)  # 1 line per epoch logging
266

267

268
if __name__ == '__main__':
269
  app.run(main)
270

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.