# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Run Classifier to score all data.

Data scorer with classifier.

This file is intended for a dataset that is split into 14 chunks.
"""

import csv
import os
import pickle
from typing import Sequence

from absl import app
from absl import flags
import jax
import numpy as np
from scipy.special import softmax
import tensorflow as tf
import transformers

from data_selection.wmt import decode
from data_selection.wmt import input_pipeline


tf.compat.v1.enable_eager_execution()

FLAGS = flags.FLAGS

flags.DEFINE_string(
    'save_dir', default=None,
    help='Directory in which to store score data.')
flags.DEFINE_integer(
    'slice', default=0,
    help='Which slice of the data to process.')
flags.DEFINE_string(
    'bert_base_dir', default=None,
    help='Directory of the German BERT base model.')
flags.DEFINE_string(
    'bert_clf_dir', default=None,
    help='Directory of the German BERT domain classifier.')
flags.DEFINE_string(
    'target_text', default=None,
    help='Filename pattern for the target text, formatted with a chunk '
         'index. This data will be labeled by the model.')
flags.DEFINE_string(
    'dataset_name', default=None,
    help='Name of the dataset if targets are not provided.')
flags.DEFINE_string(
    'data_dir', default=None,
    help='Dataset directory if targets are not provided.')
flags.DEFINE_string(
    'vocab_path', default=None,
    help='Vocab file if targets are not provided.')
flags.DEFINE_bool(
    'split_tokenizer', default=False,
    help='Use separate input and target tokenizers if targets are not '
         'provided.')
flags.DEFINE_bool(
    'clf_inputs', default=False,
    help='Classify the input language.')
flags.DEFINE_bool(
    'clf_targets', default=True,
    help='Classify the target language.')
flags.DEFINE_integer(
    'paracrawl_size', default=0,
    help='Number of examples to sample from ParaCrawl.')

PROC_SIZE = 300000
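# Each slice covers PROC_SIZE consecutive examples: slice k scores examples
# [k * PROC_SIZE, (k + 1) * PROC_SIZE). Slice 14 is treated as the final
# slice and also picks up any remainder of the dataset.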


def main(argv):
  if len(argv) > 1:
    raise app.UsageError('Too many command-line arguments.')

  # Grab pretrain text data
  if FLAGS.target_text:
    targets_decoded_pt = []
    for i in range(1, 9):
      with tf.io.gfile.GFile(FLAGS.target_text % i, 'rb') as f:
        pt_targs_tmp = pickle.load(f)
      targets_decoded_pt.extend(pt_targs_tmp)
  else:
    train_ds, (encoder_in, encoder_tgt) = input_pipeline.get_wmt_is_datasets(
        n_devices=jax.local_device_count(),
        dataset_name=FLAGS.dataset_name,
        shard_idx=jax.process_index(),
        shard_count=jax.process_count(),
        data_dir=FLAGS.data_dir,
        vocab_path=FLAGS.vocab_path,
        target_vocab_size=32000,
        batch_size=1024,
        max_length=256,
        paracrawl_size=FLAGS.paracrawl_size,
        split_tokenizer=FLAGS.split_tokenizer)

    train_data = iter(train_ds)
    eos_id = decode.EOS_ID
    def decode_tokens(encoder, toks):
      valid_toks = toks[:np.argmax(toks == eos_id) + 1].astype(np.int32)
      return encoder.detokenize(valid_toks).numpy().decode('utf-8')
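    # decode_tokens truncates at the first EOS token (inclusive) before
    # detokenizing; if no EOS is present, np.argmax returns 0 and only the
    # first token is kept.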
    targets = []
    inputs = []
    for x in train_data:
      trg = x['targets']._numpy()  # pylint:disable=protected-access
      ins = x['inputs']._numpy()  # pylint:disable=protected-access
      targets.append(trg)
      inputs.append(ins)

    # Flatten the batched targets and inputs.
    # pylint:disable=g-complex-comprehension
    targets_flat = [t for batch_t in targets for t in batch_t]
    inputs_flat = [t for batch_t in inputs for t in batch_t]
    # pylint:enable=g-complex-comprehension

    # Decode only this job's slice of the data.
    targets_decoded_pt = []
    start = PROC_SIZE * FLAGS.slice
    end = PROC_SIZE * (FLAGS.slice + 1)
    if FLAGS.slice == 14:
      end = 9999999
    for i, x in enumerate(targets_flat[start:end]):
      if FLAGS.clf_inputs:
        input_decode = decode_tokens(encoder_in, inputs_flat[i + start])
      if FLAGS.clf_targets:
        target_decode = decode_tokens(encoder_tgt, x)
      if FLAGS.clf_inputs and FLAGS.clf_targets:
        decode_tok = input_decode + ' [SEP] ' + target_decode
      else:
        decode_tok = target_decode if FLAGS.clf_targets else input_decode
      targets_decoded_pt.append(decode_tok)

  # Load model
  cache_dir = '/tmp/'  # model weights get temporarily written to this directory
  path = FLAGS.bert_base_dir
  trained_path = FLAGS.bert_clf_dir
  config = transformers.BertConfig.from_pretrained(
      os.path.join(trained_path, 'config.json'), num_labels=2,
      cache_dir=cache_dir)
  tokenizer = transformers.BertTokenizer.from_pretrained(
      path, cache_dir=cache_dir)
  model = transformers.TFBertForSequenceClassification.from_pretrained(
      os.path.join(trained_path, 'tf_model.h5'), config=config,
      cache_dir=cache_dir)
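  # Note: the classifier configuration and fine-tuned weights are loaded from
  # --bert_clf_dir, while the tokenizer comes from the base model in
  # --bert_base_dir.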

  if FLAGS.target_text:
    # If we read the entire dataset from text, select the slice to encode.
    start = PROC_SIZE * FLAGS.slice
    end = PROC_SIZE * (FLAGS.slice + 1)
    if FLAGS.slice == 14:
      end = 9999999
    input_targets = targets_decoded_pt[start:end]
  else:
    # The dataset path above already decoded only this slice.
    input_targets = targets_decoded_pt
  encoding = tokenizer(
      input_targets,
      return_tensors='tf',
      padding=True,
      truncation=True,
      max_length=512)
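  # Padding and truncation keep every example within BERT's 512-token limit;
  # longer examples are cut off before scoring.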

  train_dataset = tf.data.Dataset.from_tensor_slices((
      dict(encoding),
  ))
  batch_size = 256
  if FLAGS.clf_inputs and FLAGS.clf_targets:
    # The multilingual model is larger, so use a smaller batch size.
    batch_size = 128
  train_dataset = train_dataset.batch(batch_size)
  logits = model.predict(train_dataset)

  probs = softmax(logits.logits, axis=1)
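  # Column 1 of the softmax output is the positive-class probability
  # (presumably the in-domain class); only that column is written out below.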

  clf_score_name = FLAGS.save_dir + '/CLR_scores_' + str(FLAGS.slice) + '.csv'
  with tf.io.gfile.GFile(clf_score_name, 'w') as f:
    writer = csv.writer(f)
    for p in probs:
      writer.writerow([p[1]])


if __name__ == '__main__':
  app.run(main)