# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

16"""Run Classifier to score all data.
17
18Data scorer wtih classifier.
19
20This file is intended for a dataset that is split into 14 chunks.
21"""
22
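# Example invocation (a sketch only; the script name and every path below are
# placeholders, not real files or checkpoints):
#
#   python run_classifier_scorer.py \
#     --save_dir=/tmp/clf_scores \
#     --slice=0 \
#     --bert_base_dir=/path/to/german_bert \
#     --bert_clf_dir=/path/to/german_bert_domain_clf \
#     --dataset_name=wmt17_translate/de-en \
#     --data_dir=/path/to/tfds_data \
#     --vocab_path=/path/to/vocab.model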
import csv
import os
import pickle
from typing import Sequence

from absl import app
from absl import flags
import jax
import numpy as np
from scipy.special import softmax
import tensorflow as tf
import transformers

from data_selection.wmt import decode
from data_selection.wmt import input_pipeline

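# Eager execution is required here: the code below converts dataset tensors
# with ._numpy() and calls .numpy() on detokenized strings, which would fail
# under TF1-style graph mode.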
tf.compat.v1.enable_eager_execution()

FLAGS = flags.FLAGS

flags.DEFINE_string(
    'save_dir', default=None,
    help='Directory to store scores data.')
flags.DEFINE_integer(
    'slice', default=0,
    help='Which slice of data to process.')
flags.DEFINE_string(
    'bert_base_dir', default=None,
    help='Directory of German BERT.')
flags.DEFINE_string(
    'bert_clf_dir', default=None,
    help='Directory of German BERT domain classifier.')
flags.DEFINE_string(
    'target_text', default=None,
    help='Filename pattern with a format placeholder for the chunk index '
         '(1-8). This data will be labeled by the model.')
flags.DEFINE_string(
    'dataset_name', default=None,
    help='Name of dataset if targets not provided.')
flags.DEFINE_string(
    'data_dir', default=None,
    help='Dataset dir if targets not provided.')
flags.DEFINE_string(
    'vocab_path', default=None,
    help='Vocab file if targets not provided.')
flags.DEFINE_bool(
    'split_tokenizer', default=False,
    help='Whether to use separate tokenizers for inputs and targets if '
         'targets are not provided.')
flags.DEFINE_bool(
    'clf_inputs', default=False,
    help='Classify the input language.')
flags.DEFINE_bool(
    'clf_targets', default=True,
    help='Classify the target language.')
flags.DEFINE_integer(
    'paracrawl_size', default=0,
    help='Number of examples to sample from paracrawl.')

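# Number of examples scored per slice; slice 14 is treated as the last slice
# and takes all remaining examples.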
PROC_SIZE = 300000


def main(argv: Sequence[str]) -> None:
  if len(argv) > 1:
    raise app.UsageError('Too many command-line arguments.')

  # Grab pretrain text data.
  if FLAGS.target_text:
    targets_decoded_pt = []
    for i in range(1, 9):
      with tf.io.gfile.GFile(FLAGS.target_text % i, 'rb') as f:
        pt_targs_tmp = pickle.load(f)
      targets_decoded_pt.extend(pt_targs_tmp)
  else:
    train_ds, (encoder_in, encoder_tgt) = input_pipeline.get_wmt_is_datasets(
        n_devices=jax.local_device_count(),
        dataset_name=FLAGS.dataset_name,
        shard_idx=jax.process_index(),
        shard_count=jax.process_count(),
        data_dir=FLAGS.data_dir,
        vocab_path=FLAGS.vocab_path,
        target_vocab_size=32000,
        batch_size=1024,
        max_length=256,
        paracrawl_size=FLAGS.paracrawl_size,
        split_tokenizer=FLAGS.split_tokenizer)

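    # Each batch yielded by the pipeline is a dict holding 'inputs' and
    # 'targets' token-id arrays.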
    train_data = iter(train_ds)
    eos_id = decode.EOS_ID

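    # Detokenize up to and including the first EOS token. If no EOS is
    # present, np.argmax over the all-False mask returns 0 and only the first
    # token is kept.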
    def decode_tokens(encoder, toks):
      valid_toks = toks[:np.argmax(toks == eos_id) + 1].astype(np.int32)
      return encoder.detokenize(valid_toks).numpy().decode('utf-8')

    targets = []
    inputs = []
    for x in train_data:
      trg = x['targets']._numpy()  # pylint:disable=protected-access
      ins = x['inputs']._numpy()  # pylint:disable=protected-access
      targets.append(trg)
      inputs.append(ins)

    # Flatten the per-batch arrays into flat lists of tokenized examples.
    # pylint:disable=g-complex-comprehension
    targets_flat = [t for batch_t in targets for t in batch_t]
    inputs_flat = [t for batch_t in inputs for t in batch_t]
    # pylint:enable=g-complex-comprehension

    # Decode only the slice of examples assigned to this job.
    targets_decoded_pt = []
    start = PROC_SIZE * FLAGS.slice
    end = PROC_SIZE * (FLAGS.slice + 1)
    if FLAGS.slice == 14:
      end = 9999999
    for i, x in enumerate(targets_flat[start:end]):
      if FLAGS.clf_inputs:
        input_decode = decode_tokens(encoder_in, inputs_flat[i + start])
      if FLAGS.clf_targets:
        target_decode = decode_tokens(encoder_tgt, x)
      if FLAGS.clf_inputs and FLAGS.clf_targets:
        decode_tok = input_decode + ' [SEP] ' + target_decode
      else:
        decode_tok = target_decode if FLAGS.clf_targets else input_decode
      targets_decoded_pt.append(decode_tok)

  # Load model.
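  # Note: the tokenizer is loaded from the pretrained German BERT directory,
  # while the config and fine-tuned classifier weights come from the
  # domain-classifier directory.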
  cache_dir = '/tmp/'  # Model weights get temporarily written to this dir.
  path = FLAGS.bert_base_dir
  trained_path = FLAGS.bert_clf_dir
  config = transformers.BertConfig.from_pretrained(
      os.path.join(trained_path, 'config.json'), num_labels=2,
      cache_dir=cache_dir)
  tokenizer = transformers.BertTokenizer.from_pretrained(
      path, cache_dir=cache_dir)
  model = transformers.TFBertForSequenceClassification.from_pretrained(
      os.path.join(trained_path, 'tf_model.h5'), config=config,
      cache_dir=cache_dir)

  if FLAGS.target_text:
    # If we read the entire dataset from text, select the slice to encode.
    start = PROC_SIZE * FLAGS.slice
    end = PROC_SIZE * (FLAGS.slice + 1)
    if FLAGS.slice == 14:
      end = 9999999
    input_targets = targets_decoded_pt[start:end]
  else:
    # The targets were already decoded and sliced above, so use them as is.
    input_targets = targets_decoded_pt
  encoding = tokenizer(
      input_targets,
      return_tensors='tf',
      padding=True,
      truncation=True,
      max_length=512)

  train_dataset = tf.data.Dataset.from_tensor_slices((
      dict(encoding),
  ))
  batch_size = 256
  if FLAGS.clf_inputs and FLAGS.clf_targets:
    # The multilingual model is larger, so use a smaller batch size.
    batch_size = 128
  train_dataset = train_dataset.batch(batch_size)
  logits = model.predict(train_dataset)

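  # Softmax over the two class logits; column 1 is the probability of label 1,
  # which is the score written to the CSV below.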
  probs = softmax(logits.logits, axis=1)

  clf_score_name = os.path.join(
      FLAGS.save_dir, 'CLR_scores_%d.csv' % FLAGS.slice)
  with tf.io.gfile.GFile(clf_score_name, 'w') as f:
    writer = csv.writer(f)
    for p in probs:
      writer.writerow([p[1]])


if __name__ == '__main__':
  app.run(main)