google-research
326 строк · 12.0 Кб
1# coding=utf-8
2# Copyright 2024 The Google Research Authors.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8# http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15
16"""Creates common crawl dateset in TFRecord format."""
17
18from __future__ import absolute_import
19from __future__ import division
20from __future__ import print_function
21
22import collections
23import csv
24import json
25import operator
26import os
27from absl import app
28from absl import flags
29
30import attr
31import numpy as np
32import tensorflow.compat.v1 as tf # tf
33
34from seq2act.data_generation import common
35from seq2act.data_generation import config
36from seq2act.data_generation import proto_utils
37from seq2act.data_generation import string_utils
38
_NUM_SHARDS_DEFAULT = 10
# Number of 100%-agreement ("golden") tasks reserved for shard 0 when the
# '700-golden-in-shard0' sharding mode is selected.
GOLD_NUM_IN_SHARD0 = 700

FLAGS = flags.FLAGS
flags.DEFINE_integer(
    'num_shards', _NUM_SHARDS_DEFAULT,
    'The number of sharded files to save the created dataset.')
flags.DEFINE_enum(
    'sharding', '700-golden-in-shard0', ['hash', '700-golden-in-shard0'],
    'The way to split data to sharding')
flags.DEFINE_string(
    'input_csv_file', None, 'Input CSV file of labeled data.')
flags.DEFINE_string(
    'input_instruction_json_file', None,
    'Input Json file of downloaded instructions.')
flags.DEFINE_string(
    'output_dir', None,
    'Output directory for generated tf example proto data.')


# Debug purpose only
counters = collections.Counter()
id_batch_map = {}  # dict of {Task_id: Batch num where the task comes from}
one_agreement_ids = []
# Pick out some examples out for manual check, {"reason": list_of_examples}
chosen_examples = collections.defaultdict(list)
# Stats on some features, {"interesting_feature_name": {"value": count}}
distributions = collections.defaultdict(collections.Counter)
67
68
@attr.s
class Action(object):
  """One annotated action parsed from an annotation string.

  All `*_pos` attributes are character offsets into the (lower-cased)
  instruction string.  A (0, 0) input-content span means the action has no
  input content (see `_task_to_features_dict`).
  """
  # NOTE(review): the original applied @attr.s twice; the second application
  # found no attr.ib() fields left and generated an attribute-less __init__,
  # which is the only reason the no-argument `Action()` calls elsewhere in
  # this file worked. A single @attr.s with None defaults preserves that
  # construction pattern while keeping the fields as real attrs attributes.
  verb_type = attr.ib(default=None)
  verb_start_pos = attr.ib(default=None)
  verb_end_pos = attr.ib(default=None)
  object_desc_start_pos = attr.ib(default=None)
  object_desc_end_pos = attr.ib(default=None)
  input_content_start_pos = attr.ib(default=None)
  input_content_end_pos = attr.ib(default=None)
80
def _annotation_to_actions(annotation_str):
  """Parses an annotation string into a list of Action objects.

  Example input: [CLICK-28:31-42:81-0:0]->[CLICK-248:251-254:337-0:0]

  Args:
    annotation_str: string, '->'-separated bracketed actions, each holding a
      verb type plus three 'start:end' character spans.

  Returns:
    list of Action, one per bracketed segment, in order.
  """
  parsed = []
  for raw_action in annotation_str.split('->'):
    # Drop the surrounding '[' and ']', then split the four '-' fields.
    fields = raw_action[1:-1].split('-')
    action = Action()
    action.verb_type = fields[0]
    action.verb_start_pos, action.verb_end_pos = (
        int(part) for part in fields[1].split(':'))
    action.object_desc_start_pos, action.object_desc_end_pos = (
        int(part) for part in fields[2].split(':'))
    action.input_content_start_pos, action.input_content_end_pos = (
        int(part) for part in fields[3].split(':'))
    parsed.append(action)
  return parsed
99
100
def _task_to_features_dict(task, do_stats):
  """Converts one task to features dict.

  Args:
    task: Task instance, the task to be converted.
    do_stats: whether do stats on this task. Set it to False for debug or test.
  Returns:
    features: dict of (string, np.array) which contains columns and values:
    features = {
      'instruction_str': string of answer, np string array, shape = (1,)
      'instruction_word_id_seq': word id sequence of question, np string array,
          shape = (word_num,)

      'verb_id_seq': word id of verb, np int array, shape = (action_num,)
      'obj_desc_position_seq': index of word id of object in answer, np int
          array, shape = (actions * 2,)
      'input_str_position_seq': additional info of the action, shape =
          (action_num * 2,)

      'raters_count_per_task': shape = (1,)
      'agreement_count': shape = (1,)
    }
  Raises:
    ValueError: raise error when fail to parse actions of tasks.
  """
  answer = task['instruction'].lower()
  features = {}
  features['instruction_str'] = np.array([answer], dtype=np.string_)
  tokens, _ = string_utils.tokenize_to_ids(answer)
  features['instruction_word_id_seq'] = np.array(tokens, dtype=np.int64)

  verb_id_seq = []
  verb_str_position_seq = []
  obj_desc_position_seq = []
  input_str_position_seq = []

  for action in task['actions']:
    # Map the annotated verb name onto a canonical action type.
    try:
      verb_id = common.ActionTypes[action.verb_type.upper().strip()]
    except KeyError:
      raise ValueError('Verb "%s" cannot be recognized.' % action.verb_type)
    if verb_id == common.ActionTypes.OTHERS:
      # For OTHERS, retry with the literal verb text from the instruction.
      verb = answer[action.verb_start_pos: action.verb_end_pos].strip().lower()
      verb_id = common.VERB_ID_MAP.get(verb, common.ActionTypes.OTHERS)
    verb_id_seq.append(verb_id.value)

    # Convert the verb's character span to token positions.
    verb_str_position_seq.extend(string_utils.get_token_pos_from_char_pos(
        answer, action.verb_start_pos, action.verb_end_pos))
    # Distributions are only tracked for tasks at least two raters agreed on.
    if do_stats and task['agreement-count'] >= 2:
      distributions['longest_verb_str'][
          verb_str_position_seq[-1] - verb_str_position_seq[-2]] += 1

    obj_desc_position_seq.extend(string_utils.get_token_pos_from_char_pos(
        answer, action.object_desc_start_pos, action.object_desc_end_pos))
    if do_stats and task['agreement-count'] >= 2:
      distributions['longest_obj_desc'][
          obj_desc_position_seq[-1] - obj_desc_position_seq[-2]] += 1

    # A (0, 0) span marks "no input content" for this action.
    if not (action.input_content_start_pos == 0 and
            action.input_content_end_pos == 0):
      input_str_position_seq.extend(string_utils.get_token_pos_from_char_pos(
          answer, action.input_content_start_pos, action.input_content_end_pos))
      if do_stats and task['agreement-count'] >= 2:
        distributions['longest_input_str'][
            input_str_position_seq[-1] - input_str_position_seq[-2]] += 1
    else:
      # Keep the sequence aligned at 2 entries per action with sentinels.
      input_str_position_seq.extend([config.LABEL_DEFAULT_VALUE_INT] * 2)

  features['verb_id_seq'] = np.array(verb_id_seq, dtype=np.int64)
  features['verb_str_position_seq'] = np.array(verb_str_position_seq,
                                               dtype=np.int64)
  features['obj_desc_position_seq'] = np.array(obj_desc_position_seq,
                                               dtype=np.int64)
  features['input_str_position_seq'] = np.array(input_str_position_seq,
                                                dtype=np.int64)

  features['agreement_count'] = np.array([task['agreement-count']],
                                         dtype=np.int64)

  if do_stats:
    # Global (module-level) counters/distributions for the stats report.
    distributions['step_num'][len(task['actions'])] += 1
    distributions['longest_instruction'][len(tokens)] += 1
    counters['total_verb_refs'] += len(verb_id_seq)
    counters['total_obj_refs'] += len(obj_desc_position_seq) / 2
    counters['total_input_refs'] += (
        (len(input_str_position_seq) -
         input_str_position_seq.count(config.LABEL_DEFAULT_VALUE_INT)) / 2)
    for verb in common.ActionTypes:
      if verb.value in verb_id_seq:
        counters['Instructions contain %s in verbs' % verb.name] += 1
    # Any non-sentinel entry means at least one action carried input content.
    if input_str_position_seq.count(config.LABEL_DEFAULT_VALUE_INT) != len(
        input_str_position_seq):
      counters['Instructions contain INPUT Content'] += 1
    if ' and then ' in answer:
      chosen_examples['instruction_contains_and-then'].append(answer)
    if ' after ' in answer:
      chosen_examples['instruction_contains_after'].append(answer)
    if '. ' in answer:
      counters['instruction_contains_dot'] += 1
    if ', ' in answer:
      counters['instruction_contains_comma'] += 1

  return features
203
204
def _write_tasks_to_tf_example(id_tasks_dict, output_dir, num_shards, sharding):
  """Writes tasks as tf.Example.

  Args:
    id_tasks_dict: dict of task_id and list of Tasks.
    output_dir: string, the full path of output folder.
    num_shards: int, number of shards of output.
    sharding: from flag sharding enum, how to shard the data.
  """
  tfrecord_writers = [
      tf.python_io.TFRecordWriter(
          os.path.join(output_dir, 'commoncrawl_%d.tfrecord' % shard))
      for shard in range(num_shards)]

  def write_task(task, shard_id):
    """Converts one task and appends it to the given shard's TFRecord."""
    try:
      features = _task_to_features_dict(task, do_stats=True)
    except ValueError:
      counters['ValueError'] += 1
    else:
      tfproto = proto_utils.features_to_tf_example(features)
      tfrecord_writers[shard_id].write(tfproto.SerializeToString())
      counters['examples_count_in_dataset'] += 1

  try:
    # Sharding mode
    if sharding == 'hash':
      for task_id, tasks in id_tasks_dict.items():
        shard_id = hash(task_id) % num_shards
        for task in tasks:
          write_task(task, shard_id)

    else:  # when sharding == '700-golden-in-shard0'
      # For testing purpose, put 700 100% agreement tasks to shard_0,
      # and then put the rest tasks to shard 1~9
      testing_count = 0
      for task_id, tasks in id_tasks_dict.items():
        if (testing_count < GOLD_NUM_IN_SHARD0 and
            tasks[0]['agreement-count'] == len(tasks) and len(tasks) >= 3):
          for task in tasks:
            # Bug fix: the original wrote `tasks[0]` on every iteration.
            # With 100% agreement the copies carry the same annotation, but
            # writing the loop's own `task` is the clearly intended behavior
            # (mirrors the else-branch below).
            write_task(task, shard_id=0)
            testing_count += 1
        else:
          # Remaining tasks are hashed across shards 1..num_shards-1.
          shard_id = hash(task_id) % (num_shards - 1) + 1
          for task in tasks:
            write_task(task, shard_id)
  finally:
    # Bug fix: the original never closed the writers; unflushed buffers can
    # truncate the tail of each shard file.
    for writer in tfrecord_writers:
      writer.close()
250
251
def _read_tasks(input_csv_file, input_instruction_json_file):
  """Reads rows from CSV file containing the annotations."""

  # We use `index+url` as the ID of an instruction
  def make_task_id(index, url):
    return '%s+%s' % (str(index), url)

  # Map of task id -> instruction text, loaded from the json-lines file.
  id_instruction_dict = {}
  with open(input_instruction_json_file, 'r') as json_file:
    for line in json_file:
      if not line.strip():
        continue
      record = json.loads(line)
      task_id = make_task_id(record['index'], record['url'])
      id_instruction_dict[task_id] = record['instructions']

  found_count = 0
  missing_count = 0
  id_tasks_dict = collections.defaultdict(list)
  with open(input_csv_file, 'r') as csv_file:
    for row in csv.DictReader(csv_file):
      task_id = make_task_id(row['index'], row['url'])
      # Parse annotation/agreement unconditionally (matches original order:
      # a malformed annotation raises even when the instruction is missing).
      row['actions'] = _annotation_to_actions(row['annotation'])
      row['agreement-count'] = int(row['agreement-count'])
      if task_id in id_instruction_dict:
        row['instruction'] = id_instruction_dict[task_id]
        id_tasks_dict[task_id].append(row)
        found_count += 1
      else:
        missing_count += 1

  if missing_count == 0:
    print('All %s instructions match with annotations successfully.' %
          found_count)
  else:
    print('%s instructions match with annotations successfully.' %
          found_count)
    print('Warning: can not find instructions for %s annotations, probably you '
          'have not downloaded all the WARC files.' % missing_count)
  return id_tasks_dict
292
293
def _generate_commoncrawl_dataset():
  """Generates commoncrawl dataset with the annotations."""
  assert FLAGS.input_csv_file.endswith('.csv')
  id_tasks_dict = _read_tasks(FLAGS.input_csv_file,
                              FLAGS.input_instruction_json_file)

  _write_tasks_to_tf_example(id_tasks_dict, FLAGS.output_dir,
                             FLAGS.num_shards, FLAGS.sharding)

  def sort_dict_by_key(the_dict):
    # Deterministic ordering makes the stats report diffable between runs.
    return sorted(the_dict.items(), key=lambda item: item[0])

  # Dump all module-level debug counters/distributions collected above.
  with open(os.path.join(FLAGS.output_dir, 'stats.txt'), 'w+') as stat_file:
    stat_file.write('stat_fix_dict: %s\n' % string_utils.stat_fix_dict)
    for key, count in sort_dict_by_key(counters):
      stat_file.write('%s: %s\n' % (key, count))
    for key, examples in sort_dict_by_key(chosen_examples):
      stat_file.write('%s: %s\n' % (key, len(examples)))
    for key, distribution in distributions.items():
      stat_file.write('%s: %s\n' % (key, sort_dict_by_key(distribution)))

  # One file per "reason" bucket with the picked-out example instructions.
  for key, examples in chosen_examples.items():
    with open(os.path.join(FLAGS.output_dir, key), 'w+') as writer:
      writer.write('\n'.join(examples))
318
319
def main(_):
  # Entry point for app.run; the unused parameter receives leftover argv.
  _generate_commoncrawl_dataset()
322
323
if __name__ == '__main__':
  # Log to stderr by default so progress is visible when run as a script.
  FLAGS.set_default('logtostderr', True)
  app.run(main)
327