# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

16"""Run Classifier to score all data.
17
18Data scorer wtih classifier.
19
20This file is intended for a dataset that is split into 14 chunks.
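
Example invocation (illustrative; the script name and flag values below are
assumptions, not the exact setup used to produce the released results):

  python run_classifier.py \
    --loss=struct \
    --bert_path=/path/to/bert_base \
    --routing_path=/path/to/routing/ \
    --cns_save_dir=/path/to/routing/domain_clf/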
21"""
22
import os
import pathlib
import pickle
from typing import Any, Dict, Sequence, Union, Optional

from absl import app
from absl import flags
from absl import logging
import numpy as np
import numpy.typing as npt
import tensorflow as tf
import transformers


FLAGS = flags.FLAGS

STRUCT = 'struct'
ONE_HOT = 'onehot'

flags.DEFINE_integer(
    'train_data_size', default=500_000,
    help='Size of training data.')
flags.DEFINE_integer(
    'epochs', default=5,
    help='Epochs of training.')
flags.DEFINE_integer(
    'batch_size', default=64,
    help='Batch size.')
flags.DEFINE_integer(
    'eval_freq', default=1_000,
    help='Num steps between evals.')
flags.DEFINE_string(
    'loss', default=STRUCT,
    help='Type of loss to use.')
flags.DEFINE_string(
    'cns_save_dir', default='routing/domain_clf/',
    help='Path to save dir.')
flags.DEFINE_string(
    'bert_path', default='bert_base',
    help='Path to bert base.')
flags.DEFINE_string(
    'routing_path', default='routing/',
    help='Path to save logs and losses.')


class CNSCheckpointCallback(tf.keras.callbacks.Callback):
  """Write checkpoints to CNS."""

  def __init__(self,
               tmp_save_dir='/tmp/model_ckpt',
               cns_save_dir='routing/domain_clf/'):
    super().__init__()
    self.best_loss = 10000
    self.tmp_save_dir = tmp_save_dir
    self.cns_save_dir = cns_save_dir

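  # When validation loss improves, save the HuggingFace checkpoint locally and
  # copy it (plus a marker file named after the loss value) to the CNS save dir.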
  def _save_model(self, new_loss):
    pathlib.Path(self.tmp_save_dir).mkdir(parents=True, exist_ok=True)
    self.model.save_pretrained(self.tmp_save_dir)
    for filename in tf.io.gfile.glob(self.tmp_save_dir + '/*'):
      print(filename)
      base_filename = os.path.basename(filename)
      if not tf.io.gfile.exists(self.cns_save_dir):
        tf.io.gfile.mkdir(self.cns_save_dir)
      tf.io.gfile.copy(
          filename, self.cns_save_dir + '/' + base_filename, overwrite=True)
    with tf.io.gfile.GFile(
        self.cns_save_dir + '/chkpt{}.txt'.format(new_loss), 'w') as f:
      f.write(str(new_loss))

  def on_test_end(self, logs=None):
    new_loss = logs['loss']
    if new_loss < self.best_loss:
      self.best_loss = new_loss
      self._save_model(new_loss)


def custom_loss_function(y_true, y_pred):
  """Structured loss function."""
  lowest_loss = tf.reduce_min(y_true, axis=1)
  lowest_index = tf.argmin(y_true, axis=1)
  indexes = tf.expand_dims(lowest_index, 1)
  rows = tf.expand_dims(tf.range(tf.shape(indexes)[0], dtype=tf.int64), 1)
  ind = tf.concat([rows, indexes], axis=1)
  s_xystar = tf.gather_nd(y_pred, ind)
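  # Structured hinge: for each example, penalize when any cluster's
  # (cost + predicted score) exceeds the score of the lowest-loss cluster y*.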
  # cost of each point
  cost = tf.subtract(y_true, tf.expand_dims(lowest_loss, 1))
  addition = tf.add(cost, y_pred)
  sub = tf.subtract(addition, tf.expand_dims(s_xystar, 1))
  # max_y
  max_y = tf.reduce_max(sub, axis=1)
  return tf.maximum(0.0, max_y)


class CustomAccuracy(tf.keras.metrics.Metric):
  """Compute accuracy with explicit score inputs."""

  def __init__(self, name='custom_acc', **kwargs):
    super(CustomAccuracy, self).__init__(name=name, **kwargs)
    self.custom_acc = tf.keras.metrics.Accuracy(name='custom_acc', dtype=None)

  def update_state(
      self,
      y_true,
      y_pred,
      sample_weight=None
  ):
    lowest_index = tf.argmin(y_true, axis=1)
    highest_logit = tf.argmax(y_pred, axis=1)
    self.custom_acc.update_state(lowest_index, highest_logit)

  def result(self):
    return self.custom_acc.result()

  def reset_states(self):
    self.custom_acc.reset_states()


def main(argv):
  if len(argv) > 1:
    raise app.UsageError('Too many command-line arguments.')

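  # Load the precomputed per-cluster scores used as classification targets.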
  with tf.io.gfile.GFile(
      FLAGS.cns_save_dir + '/analysis_full/data.pkl', 'rb') as f:
    wmt_labels = pickle.load(f)

  strategy = tf.distribute.MirroredStrategy()
  with strategy.scope():
    path = FLAGS.bert_path  # pretrained model (i.e. bert-base)
    cache_dir = '/tmp/'
    config = transformers.BertConfig.from_pretrained(
        os.path.join(path, 'config.json'), num_labels=100, cache_dir=cache_dir)
    tokenizer = transformers.BertTokenizer.from_pretrained(
        path, cache_dir=cache_dir)
    model = transformers.TFBertForSequenceClassification.from_pretrained(
        os.path.join(path, 'tf_model.h5'), config=config, cache_dir=cache_dir)

  with tf.device('/CPU:0'):
    all_train_ds = []
    all_eval_ds = []
    for i in range(100):
      data_dir = (FLAGS.routing_path + '/cluster_data/k100/id_{}/'
                  .format(i))
      wmt_train = 'test_large.tsv'
      train_files = [data_dir + '/' + wmt_train]

      train_data = tf.data.experimental.CsvDataset(
          train_files,
          record_defaults=[tf.string, tf.string],
          field_delim='\t',
          use_quote_delim=False)

      def to_features_dict(eng, _):
        return eng

      train_data = train_data.map(to_features_dict)

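      # Labels: the full vector of 100 cluster scores for the structured loss,
      # or the argmin cluster index for the one-hot cross-entropy loss.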
      all_scores = [np.array(wmt_labels[i][j]) for j in range(100)]
      all_scores = np.stack(all_scores)
      if FLAGS.loss == STRUCT:
        split_scores = tf.split(all_scores, all_scores.shape[1], axis=1)
        split_scores = [tf.reshape(scores, -1) for scores in split_scores]
        label = tf.data.Dataset.from_tensor_slices(split_scores)
      elif FLAGS.loss == ONE_HOT:
        labels = np.argmin(all_scores, axis=0)
        label = tf.data.Dataset.from_tensor_slices(labels)
      train_data_w_label = tf.data.Dataset.zip((train_data, label))
      eval_data_w_label = train_data_w_label.take(100)
      train_data_w_label = train_data_w_label.skip(100)

      all_train_ds.append(train_data_w_label)
      all_eval_ds.append(eval_data_w_label)

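  # Interleave examples across the 100 per-cluster datasets (uniform sampling).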
  sample_dataset = tf.data.experimental.sample_from_datasets(all_train_ds)
  eval_sample_dataset = tf.data.experimental.sample_from_datasets(all_eval_ds)

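  # Materialize up to train_data_size examples in memory before tokenizing.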
  with tf.device('/CPU:0'):
    text = []
    labels = []
    for ex in sample_dataset:
      text.append(str(ex[0].numpy()))
      labels.append(ex[1].numpy())
      if len(labels) > FLAGS.train_data_size:
        break

    encoding = tokenizer(
        text,
        return_tensors='tf',
        padding=True,
        truncation=True,
        max_length=128)
    train_dataset = tf.data.Dataset.from_tensor_slices((dict(encoding), labels))
    train_dataset = train_dataset.batch(
        FLAGS.batch_size).shuffle(100_000).repeat(20)

    eval_text = []
    eval_labels = []
    for ex in eval_sample_dataset:
      eval_text.append(str(ex[0].numpy()))
      eval_labels.append(ex[1].numpy())
      if len(eval_labels) > 5000:
        break

    eval_encoding = tokenizer(
        eval_text,
        return_tensors='tf',
        padding=True,
        truncation=True,
        max_length=128)
    eval_dataset = tf.data.Dataset.from_tensor_slices(
        (dict(eval_encoding), eval_labels))
    eval_dataset = eval_dataset.batch(FLAGS.batch_size)

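  # Linear learning-rate decay from 4e-5 to zero over the full training run
  # (PolynomialDecay with the default power of 1).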
  num_train_steps = int(FLAGS.train_data_size / FLAGS.batch_size) * FLAGS.epochs
  decay_schedule = tf.keras.optimizers.schedules.PolynomialDecay(
      initial_learning_rate=4e-5,
      decay_steps=num_train_steps,
      end_learning_rate=0)

  with strategy.scope():
    optimizer = tf.keras.optimizers.Adam(learning_rate=decay_schedule)
    if FLAGS.loss == ONE_HOT:
      model.compile(
          optimizer=optimizer, loss=model.compute_loss, metrics=['accuracy'])
    elif FLAGS.loss == STRUCT:
      model.compile(
          optimizer=optimizer, loss=custom_loss_function,
          metrics=[CustomAccuracy()])

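  # Treat every eval_freq steps as one Keras "epoch" so that validation (and
  # checkpointing via CNSCheckpointCallback.on_test_end) runs every eval_freq
  # steps.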
  steps_per_epoch = int(FLAGS.train_data_size / FLAGS.batch_size)
  num_meta_epochs = int(steps_per_epoch / FLAGS.eval_freq * FLAGS.epochs)
  logging.info('Training %d epochs of %d steps each', num_meta_epochs,
               FLAGS.eval_freq)
  model.fit(
      train_dataset,
      epochs=num_meta_epochs,
      steps_per_epoch=FLAGS.eval_freq,
      validation_freq=1,
      validation_data=eval_dataset,
      callbacks=CNSCheckpointCallback(cns_save_dir=FLAGS.cns_save_dir),
      verbose=2)  # 1 line per epoch logging


if __name__ == '__main__':
  app.run(main)