# Repository: google-research
# File size: 195 lines · 5.1 KB
1# coding=utf-8
2# Copyright 2024 The Google Research Authors.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8# http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15
16"""Utilities for language."""
17# pylint: disable=not-an-iterable
18
19from __future__ import absolute_import
20from __future__ import division
21from __future__ import print_function
22
23import random
24import re
25
26import numpy as np
27
28
# Synonym tables. Each table maps a word or phrase to the list of candidate
# replacements; an empty-string candidate stands for dropping the word.
_RELATION_SYNONYMS = {
    'on the left side': ['left', 'on the left'],
    'on the right side': ['right', 'on the right'],
    'in front of': ['front of'],
}

_MATERIAL_SYNONYMS = {
    'matte': ['rubber', ''],
    'rubber': ['matte', ''],
    'shiny': ['metallic', ''],
    'metallic': ['shiny', ''],
}

_OBJECT_SYNONYMS = {
    # Singular forms.
    'object': ['sphere', 'object', 'thing'],
    'sphere': ['object', 'ball', 'thing'],
    'ball': ['sphere', 'object', 'thing'],
    # Plural forms.
    'objects': ['spheres', 'objects', 'things'],
    'spheres': ['objects', 'balls', 'things'],
    'balls': ['spheres', 'objects', 'things'],
}

_ADJECTIVE_SYNONYMS = {'any': ['']}

_MISC_SYNONYMS = {'are': ['is']}

# Default collection of tables used when paraphrasing.
_CLEVR_SYNONYM_TABLES = [
    _RELATION_SYNONYMS,
    _MATERIAL_SYNONYMS,
    _OBJECT_SYNONYMS,
    _ADJECTIVE_SYNONYMS,
    _MISC_SYNONYMS,
]

# One single-entry table per color; the only candidate is the empty string.
_COLORS = [
    {color: ['']} for color in ('red', 'blue', 'cyan', 'purple', 'green')
]

# For each color, the alternative colors it may be swapped with.
_OTHER_COLORS = {
    'red': ['blue', 'cyan', 'purple', 'green'],
    'blue': ['red', 'cyan', 'purple', 'green'],
    'cyan': ['blue', 'red', 'purple', 'green'],
    'purple': ['blue', 'cyan', 'red', 'green'],
    'green': ['blue', 'cyan', 'purple', 'red'],
}

# For each direction word, the opposite direction(s).
_OTHER_DIRECTIONS = {
    'left': ['right'],
    'right': ['left'],
    'front': ['behind'],
    'behind': ['front'],
}
88
89
def get_vocab_path(cfg):
  """Return the vocabulary path stored on the config.

  Args:
    cfg: configuration object exposing a `vocab_path` attribute

  Returns:
    the value of `cfg.vocab_path`
  """
  return cfg.vocab_path
96
97
def instruction_type(instruction):
  """Classify an instruction by its length.

  Args:
    instruction: any object supporting `len`

  Returns:
    'unary' when the length is below 40, otherwise 'regular'
  """
  is_short = len(instruction) < 40
  return 'unary' if is_short else 'regular'
103
104
def pad_to_max_length(data, max_l=None, eos_token=0):
  """Right-pad every sequence in `data` to a common length.

  Args:
    data: a re-iterable collection of sequences
    max_l: target length; when falsy (None or 0) the length of the longest
      sequence in `data` is used instead
    eos_token: value appended as padding

  Returns:
    numpy array built from the padded sequences
  """
  target_len = max_l
  if not target_len:
    # -1 is the floor, so an empty `data` yields an empty array.
    target_len = max([-1] + [len(seq) for seq in data])
  # A non-positive repeat count yields an empty list, so sequences already at
  # (or beyond) the target length are copied through unchanged.
  rows = [list(seq) + [eos_token] * (target_len - len(seq)) for seq in data]
  return np.array(rows)
120
121
def pad_sequence(data, max_l=None, eos_token=0):
  """Pad a single sequence to max_l with eos_token.

  Args:
    data: the sequence to pad
    max_l: target length; when None the sequence is returned unpadded
      (previously the default None failed when compared against len(data))
    eos_token: value appended as padding

  Returns:
    numpy array of length max_l (or len(data) when max_l is None)

  Raises:
    ValueError: if data is longer than max_l
  """
  length = len(data)
  if max_l is None:
    # No target given: the sequence already has its "maximum" length.
    max_l = length
  if length > max_l:
    raise ValueError('data longer than max_l')
  if length == max_l:
    return np.array(data)
  return np.array(list(data) + [eos_token] * (max_l - length))
132
133
def _substitute_synonyms(text, table):
  """Replace every key of `table` found in `text` with a random synonym.

  Keys are matched and replaced literally (no regex interpretation).

  Args:
    text: the sentence to rewrite
    table: dict mapping a phrase to a list of candidate replacements

  Returns:
    a (new_text, replaced) tuple; replaced is True if any key matched
  """
  replaced = False
  for phrase in table:
    if phrase in text:
      text = text.replace(phrase, random.choice(table[phrase]))
      replaced = True
  return text, replaced


def paraphrase_sentence(text, synonym_tables=None, delete_color=False, k=2):
  """Paraphrase a sentence.

  Phrases are matched as plain substrings and replaced literally, so keys and
  replacements containing regex metacharacters or backslashes are safe.

  Args:
    text: text to be paraphrased
    synonym_tables: a list of tables (dicts) mapping a phrase to its candidate
      synonyms; defaults to the CLEVR tables when empty or None
    delete_color: whether to delete colors from sentences; when set, with
      probability 0.5 the color tables are used instead of the sampled ones
    k: number of synonym tables to sample; clamped to the number of tables
      available

  Returns:
    paraphrased text
  """
  if not synonym_tables:
    synonym_tables = _CLEVR_SYNONYM_TABLES
  # Clamp so a short custom table list does not make random.sample raise.
  tables = random.sample(synonym_tables, min(k, len(synonym_tables)))
  if delete_color and random.uniform(0, 1) < 0.5:
    # Visit the color tables in random order and stop after the first table
    # that produced a substitution.
    for table in random.sample(_COLORS, len(_COLORS)):
      text, replaced = _substitute_synonyms(text, table)
      if replaced:
        break
  else:
    for table in tables:
      text, _ = _substitute_synonyms(text, table)
  return text
165
166
def negate_unary_sentence(text):
  """Negate an instruction involving a single object.

  The sentence is split on single spaces. The last color word and the last
  direction word found are the mutation candidates. A single random toss then
  selects the mutation:
    * toss < 0.33: swap the color (if a color is present)
    * 0.33 < toss < 0.66: swap the direction (if a direction is present)
    * otherwise: swap both (only if both are present)

  Args:
    text: the instruction to negate

  Returns:
    the mutated sentence, or None when the selected mutation cannot be applied
  """
  words = text.split(' ')
  candidates = {}
  for idx, word in enumerate(words):
    if word in _OTHER_COLORS:
      candidates['color'] = (idx, word)
    elif word in _OTHER_DIRECTIONS:
      candidates['direction'] = (idx, word)
  has_color = 'color' in candidates
  has_direction = 'direction' in candidates

  # One draw decides the branch; the original drew a second, independent
  # number for the direction branch, which skewed the intended split.
  toss = random.random()
  if toss < 0.33 and has_color:
    idx, color = candidates['color']
    words[idx] = random.choice(_OTHER_COLORS[color])
  elif 0.33 < toss < 0.66 and has_direction:
    idx, direction = candidates['direction']
    words[idx] = random.choice(_OTHER_DIRECTIONS[direction])
  elif has_direction and has_color:
    idx, color = candidates['color']
    words[idx] = random.choice(_OTHER_COLORS[color])
    idx, direction = candidates['direction']
    words[idx] = random.choice(_OTHER_DIRECTIONS[direction])
  else:
    return None
  return ' '.join(words)
196