# google-research (108 lines, 3.2 KB)
# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""CFQ and SCAN tasks for T5."""
import string

import t5.data
from t5.data import postprocessors as t5_postprocessors
from t5.evaluation import metrics as t5_metrics
import tensorflow as tf
import tensorflow_datasets as tfds

# Short aliases for the T5 registry and task class used by the task
# registration loops in this module.
TaskRegistry = t5.data.TaskRegistry
TfdsTask = t5.data.TfdsTask
28
def tokenize_punctuation(text):
  """Separates punctuation characters from their neighbours with spaces.

  Args:
    text: The string to tokenize.

  Returns:
    A copy of `text` in which every character contained in
    `string.punctuation` stands alone as its own space-delimited token, runs
    of whitespace are collapsed to a single space, and leading/trailing
    whitespace is removed.
  """
  spaced = ''.join(
      f' {char} ' if char in string.punctuation else char for char in text)
  # split() without arguments drops empty fields, so re-joining with a single
  # space normalizes the padding inserted above.
  return ' '.join(spaced.split())
33
def preprocess_sparql(query):
  """Normalizes a SPARQL query string into space-separated tokens.

  The following rewrites are applied, in order:
    * `count(*)` is expanded to the separate tokens `count ( * )`.
    * A leading `ns:` is stripped from each whitespace-delimited token.
    * A leading `m.` (checked after the `ns:` strip) is rewritten to `m_`.
    * Tokens are re-joined with single spaces, which collapses all whitespace
      (including real newlines).
    * Any remaining literal two-character sequence "backslash followed by n"
      is replaced with a space.

  Args:
    query: The SPARQL query as a Python string.

  Returns:
    The normalized query string.
  """
  # Split the parentheses and star of the count aggregate into own tokens.
  query = query.replace('count(*)', 'count ( * )')

  tokens = []
  for token in query.split():
    # Drop the 'ns:' namespace prefix.
    if token.startswith('ns:'):
      token = token[len('ns:'):]
    # Rewrite the 'm.' mid prefix so the token no longer contains the dot.
    if token.startswith('m.'):
      token = 'm_' + token[len('m.'):]
    tokens.append(token)

  # '\\n' is an escaped backslash plus 'n', i.e. a textual "\n" embedded in
  # the data rather than a newline character.
  return ' '.join(tokens).replace('\\n', ' ')
51
def cfq_preprocess(dataset):
  """Converts CFQ examples into `inputs`/`targets` text features.

  The `question` feature is passed through `tokenize_punctuation` and the
  `query` feature through `preprocess_sparql`. No prefix is added to the
  input.

  Args:
    dataset: A dataset whose elements are indexable by the keys 'question'
      and 'query' (presumably a `tf.data.Dataset` of string tensors; confirm
      against the TFDS `cfq` feature spec).

  Returns:
    The result of `dataset.map`, whose elements are dicts with the keys
    'inputs' and 'targets', each reshaped to a scalar string tensor.
  """

  def compute_inputs_and_targets(inputs, targets):
    # Called through tf.py_function below, so the arguments expose .numpy()
    # and the pure-Python string helpers can be applied to their values.
    inputs = tf.compat.as_text(inputs.numpy())
    inputs = tokenize_punctuation(inputs)
    targets = tf.compat.as_text(targets.numpy())
    targets = preprocess_sparql(targets)

    return inputs, targets

  def map_fn(x):
    inputs, targets = tf.py_function(
        compute_inputs_and_targets,
        inp=[x['question'], x['query']],
        Tout=[tf.string, tf.string])
    return {
        # The reshape is necessary as otherwise the tensor has unknown rank.
        'inputs': tf.reshape(inputs, shape=[]),
        'targets': tf.reshape(targets, shape=[])
    }

  return dataset.map(map_fn, num_parallel_calls=tf.data.experimental.AUTOTUNE)
76
# Register one T5 task per CFQ split, named `cfq_<split>` and backed by the
# TFDS dataset config `cfq/<split>` at version 1.2.0.
for split in ['mcd1', 'mcd2', 'mcd3', '2m', 'random']:
  TaskRegistry.add(
      f'cfq_{split}',
      TfdsTask,
      tfds_name=f'cfq/{split}:1.2.0',
      text_preprocessor=cfq_preprocess,
      postprocess_fn=t5_postprocessors.lower_text,
      metric_fns=[t5_metrics.sequence_accuracy])
86
def scan_preprocess(dataset):
  """Converts SCAN examples into `inputs`/`targets` text features.

  Args:
    dataset: A dataset whose elements are indexable by the keys 'commands'
      and 'actions' (presumably a `tf.data.Dataset` of string tensors; confirm
      against the TFDS `scan` feature spec).

  Returns:
    The result of `dataset.map`, whose elements are dicts where 'inputs' is
    the command prefixed with 'executescancommand:' and a space, and
    'targets' is the unmodified 'actions' feature.
  """

  def scan_map(sample):
    prefixed_command = tf.strings.join(
        ['executescancommand:', sample['commands']], separator=' ')
    return {
        'inputs': prefixed_command,
        'targets': sample['actions'],
    }

  return dataset.map(scan_map, num_parallel_calls=tf.data.experimental.AUTOTUNE)
100
# Register one T5 task per SCAN builder config, named `scan_<config>` and
# backed by the TFDS dataset config `scan/<config>` at version 1.1.1.
# NOTE(review): the builder is constructed at import time to enumerate the
# config names; confirm this does not trigger a download or data preparation.
for split in tfds.builder('scan').builder_configs:
  TaskRegistry.add(
      f'scan_{split}',
      TfdsTask,
      tfds_name=f'scan/{split}:1.1.1',
      text_preprocessor=scan_preprocess,
      postprocess_fn=t5_postprocessors.lower_text,
      metric_fns=[t5_metrics.sequence_accuracy])