google-research

Форк
0
/
create_commoncrawl_dataset.py 
326 строк · 12.0 Кб
1
# coding=utf-8
2
# Copyright 2024 The Google Research Authors.
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
#     http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15

16
"""Creates common crawl dateset in TFRecord format."""
17

18
from __future__ import absolute_import
19
from __future__ import division
20
from __future__ import print_function
21

22
import collections
23
import csv
24
import json
25
import operator
26
import os
import zlib
27
from absl import app
28
from absl import flags
29

30
import attr
31
import numpy as np
32
import tensorflow.compat.v1 as tf  # tf
33

34
from seq2act.data_generation import common
35
from seq2act.data_generation import config
36
from seq2act.data_generation import proto_utils
37
from seq2act.data_generation import string_utils
38

39
# Default number of output TFRecord shards.
_NUM_SHARDS_DEFAULT = 10
# Maximum number of fully-agreed task ids placed in shard 0 by the
# '700-golden-in-shard0' sharding mode.
GOLD_NUM_IN_SHARD0 = 700

FLAGS = flags.FLAGS

flags.DEFINE_integer(
    name='num_shards',
    default=_NUM_SHARDS_DEFAULT,
    help='The number of sharded files to save the created dataset.')
flags.DEFINE_enum(
    name='sharding',
    default='700-golden-in-shard0',
    enum_values=['hash', '700-golden-in-shard0'],
    help='The way to split data to sharding')
flags.DEFINE_string(
    name='input_csv_file',
    default=None,
    help='Input CSV file of labeled data.')
flags.DEFINE_string(
    name='input_instruction_json_file',
    default=None,
    help='Input Json file of downloaded instructions.')
flags.DEFINE_string(
    name='output_dir',
    default=None,
    help='Output directory for generated tf example proto data.')


# Debug-only, module-level accumulators. counters, chosen_examples and
# distributions are written into the output directory by
# _generate_commoncrawl_dataset.
counters = collections.Counter()
# dict of {task_id: batch num where the task comes from}.
# NOTE(review): not read anywhere in this file.
id_batch_map = {}
# NOTE(review): not read anywhere in this file.
one_agreement_ids = []
# Examples picked out for manual check, {"reason": list_of_examples}.
chosen_examples = collections.defaultdict(list)
# Stats on some features, {"interesting_feature_name": {"value": count}}.
distributions = collections.defaultdict(collections.Counter)
67

68

69
@attr.s
class Action(object):
  """One step of an annotation, e.g. `[CLICK-28:31-42:81-0:0]`.

  Every field defaults to None so that `Action()` can be created empty and
  then filled in attribute by attribute.

  Attributes:
    verb_type: str, the verb name as written in the annotation, e.g. 'CLICK'.
    verb_start_pos: int, start character offset of the verb in the
      instruction.
    verb_end_pos: int, end character offset (exclusive) of the verb.
    object_desc_start_pos: int, start character offset of the object
      description.
    object_desc_end_pos: int, end character offset (exclusive) of the object
      description.
    input_content_start_pos: int, start character offset of the input
      content; 0 together with input_content_end_pos == 0 means the action
      has no input content.
    input_content_end_pos: int, end character offset (exclusive) of the
      input content.
  """
  verb_type = attr.ib(default=None)
  verb_start_pos = attr.ib(default=None)
  verb_end_pos = attr.ib(default=None)
  object_desc_start_pos = attr.ib(default=None)
  object_desc_end_pos = attr.ib(default=None)
  input_content_start_pos = attr.ib(default=None)
  input_content_end_pos = attr.ib(default=None)
79

80

81
def _annotation_to_actions(annotation_str):
  """Parses an annotation string into a list of `Action`s.

  The annotation is a '->'-separated chain of bracketed steps, e.g.
  `[CLICK-28:31-42:81-0:0]->[CLICK-248:251-254:337-0:0]`. Inside the
  brackets, '-' separates the verb type from three `start:end` integer
  spans: the verb, the object description and the input content.

  Args:
    annotation_str: str, the raw annotation.

  Returns:
    list of Action, one per step, in annotation order.
  """

  def parse_span(span_str):
    # Exactly two ints are expected; unpacking raises ValueError otherwise.
    start, end = [int(pos) for pos in span_str.split(':')]
    return start, end

  parsed = []
  for step in annotation_str.split('->'):
    # Drop the first and last characters (the '[' and ']').
    fields = step[1:-1].split('-')
    action = Action()
    action.verb_type = fields[0]
    action.verb_start_pos, action.verb_end_pos = parse_span(fields[1])
    (action.object_desc_start_pos,
     action.object_desc_end_pos) = parse_span(fields[2])
    (action.input_content_start_pos,
     action.input_content_end_pos) = parse_span(fields[3])
    parsed.append(action)
  return parsed
99

100

101
def _task_to_features_dict(task, do_stats):
  """Converts one task to features dict.

  Args:
    task: one task row as produced by `_read_tasks`; read for the keys
      'instruction' (str), 'actions' (list of Action) and
      'agreement-count' (int).
    do_stats: whether to record stats on this task in the module-level
      `counters`, `distributions` and `chosen_examples`. Set it to False for
      debug or test.

  Returns:
    features: dict of (string, np.array) which contains columns and values:
    features = {
      'instruction_str': the lower-cased instruction as UTF-8 bytes, np bytes
          array, shape = (1,)
      'instruction_word_id_seq': word id sequence of the instruction, np
          int64 array, shape = (word_num,)
      'verb_id_seq': `common.ActionTypes` value of each action's verb, np
          int64 array, shape = (action_num,)
      'verb_str_position_seq': token positions of each verb span, np int64
          array
      'obj_desc_position_seq': token positions of each object description
          span, np int64 array
      'input_str_position_seq': token positions of each input content span,
          or a pair of config.LABEL_DEFAULT_VALUE_INT for an action without
          input content, np int64 array
      'agreement_count': np int64 array, shape = (1,)
    }
    The three *_position_seq arrays presumably have shape
    (action_num * 2,), assuming string_utils.get_token_pos_from_char_pos
    returns a (start, end) pair -- TODO confirm.

  Raises:
    ValueError: raise error when an action's verb type is not a name in
      common.ActionTypes.
  """
  answer = task['instruction'].lower()
  features = {}
  # Encode explicitly: a numpy bytes array would otherwise encode with ASCII
  # and raise UnicodeEncodeError (a ValueError, so the task was silently
  # dropped) on non-ASCII instructions. np.bytes_ replaces np.string_, which
  # was removed in NumPy 2.0.
  answer_bytes = answer if isinstance(answer, bytes) else answer.encode('utf-8')
  features['instruction_str'] = np.array([answer_bytes], dtype=np.bytes_)
  tokens, _ = string_utils.tokenize_to_ids(answer)
  features['instruction_word_id_seq'] = np.array(tokens, dtype=np.int64)

  # Span-length distributions are only collected for tasks on which at least
  # two raters agreed.
  span_stats = do_stats and task['agreement-count'] >= 2
  no_input_content = [config.LABEL_DEFAULT_VALUE_INT] * 2

  verb_id_seq = []
  verb_str_position_seq = []
  obj_desc_position_seq = []
  input_str_position_seq = []

  def add_span(position_seq, start, end, stat_name):
    """Appends the token positions of answer[start:end] and records stats."""
    position_seq.extend(
        string_utils.get_token_pos_from_char_pos(answer, start, end))
    if span_stats:
      distributions[stat_name][position_seq[-1] - position_seq[-2]] += 1

  for action in task['actions']:
    try:
      verb_id = common.ActionTypes[action.verb_type.upper().strip()]
    except KeyError:
      raise ValueError('Verb "%s" cannot be recognized.' % action.verb_type)
    if verb_id == common.ActionTypes.OTHERS:
      # Refine a generic OTHERS verb by looking up the verb text itself.
      verb = answer[action.verb_start_pos:action.verb_end_pos].strip().lower()
      verb_id = common.VERB_ID_MAP.get(verb, common.ActionTypes.OTHERS)
    verb_id_seq.append(verb_id.value)

    add_span(verb_str_position_seq, action.verb_start_pos,
             action.verb_end_pos, 'longest_verb_str')
    add_span(obj_desc_position_seq, action.object_desc_start_pos,
             action.object_desc_end_pos, 'longest_obj_desc')

    # A 0:0 input span means the action has no input content.
    if not (action.input_content_start_pos == 0 and
            action.input_content_end_pos == 0):
      add_span(input_str_position_seq, action.input_content_start_pos,
               action.input_content_end_pos, 'longest_input_str')
    else:
      input_str_position_seq.extend(no_input_content)

  features['verb_id_seq'] = np.array(verb_id_seq, dtype=np.int64)
  features['verb_str_position_seq'] = np.array(verb_str_position_seq,
                                               dtype=np.int64)
  features['obj_desc_position_seq'] = np.array(obj_desc_position_seq,
                                               dtype=np.int64)
  features['input_str_position_seq'] = np.array(input_str_position_seq,
                                                dtype=np.int64)

  features['agreement_count'] = np.array([task['agreement-count']],
                                         dtype=np.int64)

  if do_stats:
    distributions['step_num'][len(task['actions'])] += 1
    distributions['longest_instruction'][len(tokens)] += 1
    counters['total_verb_refs'] += len(verb_id_seq)
    # Position seqs hold two entries per reference; floor division keeps
    # the counters integral.
    counters['total_obj_refs'] += len(obj_desc_position_seq) // 2
    num_default_positions = input_str_position_seq.count(
        config.LABEL_DEFAULT_VALUE_INT)
    counters['total_input_refs'] += (
        (len(input_str_position_seq) - num_default_positions) // 2)
    for verb in common.ActionTypes:
      if verb.value in verb_id_seq:
        counters['Instructions contain %s in verbs' % verb.name] += 1
    if num_default_positions != len(input_str_position_seq):
      counters['Instructions contain INPUT Content'] += 1
    if ' and then ' in answer:
      chosen_examples['instruction_contains_and-then'].append(answer)
    if ' after ' in answer:
      chosen_examples['instruction_contains_after'].append(answer)
    if '. ' in answer:
      counters['instruction_contains_dot'] += 1
    if ', ' in answer:
      counters['instruction_contains_comma'] += 1

  return features
203

204

205
def _write_tasks_to_tf_example(id_tasks_dict, output_dir, num_shards, sharding):
  """Writes tasks as tf.Example records into sharded TFRecord files.

  Shard i is written to `<output_dir>/commoncrawl_<i>.tfrecord`. Tasks whose
  feature conversion raises ValueError are skipped and counted in
  `counters['ValueError']`; written tasks are counted in
  `counters['examples_count_in_dataset']`.

  Args:
    id_tasks_dict: dict of task_id and list of Tasks.
    output_dir: string, the full path of output folder.
    num_shards: int, number of shards of output.
    sharding: from flag sharding enum, how to shard the data. 'hash' puts
      all tasks of a task_id into one shard picked by hashing the task_id.
      Any other value ('700-golden-in-shard0') puts the tasks of up to
      GOLD_NUM_IN_SHARD0 task_ids that have >= 3 tasks, all agreeing, into
      shard 0 and hashes the remaining task_ids into shards
      1..num_shards-1.

  Raises:
    ValueError: if num_shards < 1, or num_shards < 2 for a sharding mode
      that reserves shard 0.
  """
  if num_shards < 1:
    raise ValueError('num_shards must be at least 1, got %s.' % num_shards)
  if sharding != 'hash' and num_shards < 2:
    raise ValueError('Sharding %r reserves shard 0 for golden tasks and '
                     'needs num_shards >= 2, got %s.' % (sharding, num_shards))

  def stable_hash(task_id):
    # The built-in hash() of a str is salted per process (PYTHONHASHSEED),
    # so it would assign tasks to different shards on every run.
    data = task_id if isinstance(task_id, bytes) else task_id.encode('utf-8')
    return zlib.crc32(data) & 0xffffffff

  tfrecord_writers = []

  def write_task(task, shard_id):
    try:
      features = _task_to_features_dict(task, do_stats=True)
    except ValueError:
      counters['ValueError'] += 1
    else:
      tfproto = proto_utils.features_to_tf_example(features)
      tfrecord_writers[shard_id].write(tfproto.SerializeToString())
      counters['examples_count_in_dataset'] += 1

  try:
    for shard in range(num_shards):
      tfrecord_writers.append(tf.python_io.TFRecordWriter(
          os.path.join(output_dir, 'commoncrawl_%d.tfrecord' % shard)))

    # Sharding mode.
    if sharding == 'hash':
      for task_id, tasks in id_tasks_dict.items():
        shard_id = stable_hash(task_id) % num_shards
        for task in tasks:
          write_task(task, shard_id)
    else:  # sharding == '700-golden-in-shard0'
      # For testing purpose, put up to GOLD_NUM_IN_SHARD0 task ids with 100%
      # agreement into shard 0, and the rest into shards 1..num_shards-1.
      golden_count = 0
      for task_id, tasks in id_tasks_dict.items():
        if (golden_count < GOLD_NUM_IN_SHARD0 and
            tasks[0]['agreement-count'] == len(tasks) and len(tasks) >= 3):
          for task in tasks:
            write_task(task, shard_id=0)
          golden_count += 1
        else:
          shard_id = stable_hash(task_id) % (num_shards - 1) + 1
          for task in tasks:
            write_task(task, shard_id)
  finally:
    # Close every writer so buffered records are flushed to disk.
    for writer in tfrecord_writers:
      writer.close()
250

251

252
def _read_tasks(input_csv_file, input_instruction_json_file):
  """Reads annotation rows from the CSV file and attaches their instructions.

  Args:
    input_csv_file: path of the CSV file with one annotation per row; each
      row is read for the columns 'index', 'url', 'annotation' and
      'agreement-count'.
    input_instruction_json_file: path of a JSON-lines file; each non-blank
      line is an object read for the keys 'index', 'url' and
      'instructions'.

  Returns:
    collections.defaultdict(list) mapping task id ('<index>+<url>') to the
    list of CSV rows (dicts) with that id. Each kept row gains 'actions'
    (list of Action), 'instruction', and an int 'agreement-count'. Rows
    without a downloaded instruction are dropped and reported on stdout.
  """

  # We use `index+url` as the ID of an instruction.
  def get_task_id(index, url):
    return '%s+%s' % (str(index), url)

  id_instruction_dict = {}
  # Decode explicitly rather than relying on the locale's default encoding;
  # instructions come from web pages and may be non-ASCII.
  with open(input_instruction_json_file, 'r', encoding='utf-8') as f:
    for line in f:
      if line.strip():
        json_dict = json.loads(line)
        id_instruction_dict[get_task_id(
            json_dict['index'], json_dict['url'])] = json_dict['instructions']

  instruction_found = 0
  instruction_not_found = 0
  id_tasks_dict = collections.defaultdict(list)
  # newline='' as required by the csv module; 'utf-8-sig' also strips a
  # leading BOM that would otherwise corrupt the first column name.
  with open(input_csv_file, 'r', encoding='utf-8-sig', newline='') as f:
    reader = csv.DictReader(f)
    for row in reader:
      task_id = get_task_id(row['index'], row['url'])
      row['actions'] = _annotation_to_actions(row['annotation'])
      row['agreement-count'] = int(row['agreement-count'])
      if task_id in id_instruction_dict:
        row['instruction'] = id_instruction_dict[task_id]
        id_tasks_dict[task_id].append(row)
        instruction_found += 1
      else:
        instruction_not_found += 1

  if instruction_not_found == 0:
    print('All %s instructions match with annotations successfully.' %
          instruction_found)
  else:
    print('%s instructions match with annotations successfully.' %
          instruction_found)
    print('Warning: can not find instructions for %s annotations, probably you '
          'have not downloaded all the WARC files.' % instruction_not_found)
  return id_tasks_dict
292

293

294
def _generate_commoncrawl_dataset():
  """Generates commoncrawl dataset with the annotations.

  Reads annotations from --input_csv_file and instructions from
  --input_instruction_json_file, writes sharded TFRecords into --output_dir,
  then writes 'stats.txt' plus one file per `chosen_examples` reason into
  the same directory.

  Raises:
    ValueError: if a required flag is unset or --input_csv_file does not end
      with '.csv'.
  """
  for flag_name in ('input_csv_file', 'input_instruction_json_file',
                    'output_dir'):
    if not getattr(FLAGS, flag_name):
      raise ValueError('--%s must be set.' % flag_name)
  # An explicit check rather than `assert`, which `python -O` strips.
  if not FLAGS.input_csv_file.endswith('.csv'):
    raise ValueError('--input_csv_file must end with ".csv", got %r.' %
                     FLAGS.input_csv_file)

  id_tasks_dict = _read_tasks(FLAGS.input_csv_file,
                              FLAGS.input_instruction_json_file)

  _write_tasks_to_tf_example(id_tasks_dict, FLAGS.output_dir,
                             FLAGS.num_shards, FLAGS.sharding)

  def sort_dict_by_key(the_dict):
    return sorted(the_dict.items(), key=operator.itemgetter(0))

  # Stats and examples may contain non-ASCII instruction text.
  stats_path = os.path.join(FLAGS.output_dir, 'stats.txt')
  with open(stats_path, 'w', encoding='utf-8') as stat_file:
    stat_file.write('stat_fix_dict: %s\n' % string_utils.stat_fix_dict)
    for key, count in sort_dict_by_key(counters):
      stat_file.write('%s: %s\n' % (key, count))
    for key, examples in sort_dict_by_key(chosen_examples):
      stat_file.write('%s: %s\n' % (key, len(examples)))
    for key, distribution in sort_dict_by_key(distributions):
      stat_file.write('%s: %s\n' % (key, sort_dict_by_key(distribution)))

  # One file per reason, named after it, with one example per line.
  for key, examples in chosen_examples.items():
    with open(os.path.join(FLAGS.output_dir, key), 'w',
              encoding='utf-8') as writer:
      writer.write('\n'.join(examples))
318

319

320
def main(_):
  """absl entry point; positional command-line arguments are ignored."""
  _generate_commoncrawl_dataset()


if __name__ == '__main__':
  # Default the absl `logtostderr` flag to True when run as a script.
  FLAGS.set_default('logtostderr', True)
  app.run(main)
327

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.