google-research

Форк
0
/
cfq_scan_tasks.py 
108 строк · 3.2 Кб
1
# coding=utf-8
2
# Copyright 2024 The Google Research Authors.
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
#     http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15

16
"""CFQ and SCAN tasks for T5."""
17
import string
18

19
import t5.data
20
from t5.data import postprocessors as t5_postprocessors
21
from t5.evaluation import metrics as t5_metrics
22
import tensorflow as tf
23
import tensorflow_datasets as tfds
24

25
TaskRegistry = t5.data.TaskRegistry
26
TfdsTask = t5.data.TfdsTask
27

28

29
def tokenize_punctuation(text):
  """Separate punctuation into standalone tokens and normalize whitespace.

  Args:
    text: Input string (e.g. a natural-language question).

  Returns:
    `text` with every ASCII punctuation character surrounded by single
    spaces and all runs of whitespace collapsed to one space.
  """
  spaced = ''.join(
      ' %s ' % ch if ch in string.punctuation else ch for ch in text)
  return ' '.join(spaced.split())
32

33

34
def preprocess_sparql(query):
  """Do various preprocessing on the SPARQL query.

  Args:
    query: SPARQL query string.

  Returns:
    The query with 'count(*)' split into separate tokens, 'ns:' namespace
    prefixes stripped, Freebase mids rewritten from 'm.x' to 'm_x', and
    literal '\\n' sequences replaced by spaces.
  """

  def _normalize(token):
    # Strip the 'ns:' namespace prefix first; the remainder may then be a
    # mid ('m.x'), which is rewritten to 'm_x'.
    if token.startswith('ns:'):
      token = token[3:]
    return 'm_' + token[2:] if token.startswith('m.') else token

  # Tokenize braces: make 'count(*)' split into five separate tokens.
  spaced = query.replace('count(*)', 'count ( * )')
  return ' '.join(_normalize(t) for t in spaced.split()).replace('\\n', ' ')
50

51

52
def cfq_preprocess(dataset):
  """Select input/target features and add prefix to input.

  Args:
    dataset: a tf.data.Dataset of CFQ examples with 'question' and 'query'
      string features.

  Returns:
    A tf.data.Dataset of {'inputs', 'targets'} scalar string features.
  """

  def _to_text_pair(question, query):
    # Executed eagerly through tf.py_function, so .numpy() is available.
    question = tokenize_punctuation(tf.compat.as_text(question.numpy()))
    query = preprocess_sparql(tf.compat.as_text(query.numpy()))
    return question, query

  def _convert(example):
    inputs, targets = tf.py_function(
        _to_text_pair,
        inp=[example['question'], example['query']],
        Tout=[tf.string, tf.string])
    # The reshape is necessary as otherwise the tensor has unknown rank.
    return {
        'inputs': tf.reshape(inputs, shape=[]),
        'targets': tf.reshape(targets, shape=[])
    }

  return dataset.map(
      _convert, num_parallel_calls=tf.data.experimental.AUTOTUNE)
75

76

77
# Register one T5 task per CFQ split: the three MCD splits, the 2m split,
# and the random split.
for _cfq_split in ('mcd1', 'mcd2', 'mcd3', '2m', 'random'):
  TaskRegistry.add(
      f'cfq_{_cfq_split}',
      TfdsTask,
      tfds_name=f'cfq/{_cfq_split}:1.2.0',
      text_preprocessor=cfq_preprocess,
      postprocess_fn=t5_postprocessors.lower_text,
      metric_fns=[t5_metrics.sequence_accuracy])
85

86

87
def scan_preprocess(dataset):
  """Select input/target features and add prefix to input.

  Args:
    dataset: a tf.data.Dataset of SCAN examples with 'commands' and
      'actions' string features.

  Returns:
    A tf.data.Dataset of {'inputs', 'targets'} string features, where
    inputs are the commands prefixed with 'executescancommand:'.
  """

  def _add_prefix(example):
    prefixed = tf.strings.join(
        ['executescancommand:', example['commands']], ' ')
    return {'inputs': prefixed, 'targets': example['actions']}

  return dataset.map(
      _add_prefix, num_parallel_calls=tf.data.experimental.AUTOTUNE)
99

100

101
# Register one T5 task per SCAN builder config (iterating the dict yields
# the same keys as .keys(), in the same order).
for _scan_split in tfds.builder('scan').builder_configs:
  TaskRegistry.add(
      f'scan_{_scan_split}',
      TfdsTask,
      tfds_name=f'scan/{_scan_split}:1.1.1',
      text_preprocessor=scan_preprocess,
      postprocess_fn=t5_postprocessors.lower_text,
      metric_fns=[t5_metrics.sequence_accuracy])
109

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.