google-research

Форк
0
/
convert_character_mining_data.py 
227 строк · 6.4 Кб
1
# coding=utf-8
2
# Copyright 2024 The Google Research Authors.
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
#     http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15

16
"""Converts character mining json data to Contrack text format."""
17

18
import json
19
import os
20
import re
21
from typing import List, Sequence
22

23
from absl import app
24
from absl import flags
25
from absl import logging
26

27
import tensorflow as tf
28

29
flags.DEFINE_string(
    name='input_dir',
    default='',
    help='Path to the directory containing the input data json files.')

flags.DEFINE_string(
    name='output_dir',
    default='',
    help='Path to the directory where the conversations are written.')

# Flag values; read below as FLAGS.input_dir and FLAGS.output_dir.
FLAGS = flags.FLAGS
38

39
# Full names of the characters that get their own entity id. Any other label
# is mapped to a group id by str_to_id().
KNOWN_CHARACTERS = frozenset({
    'Ross Geller',
    'Rachel Green',
    'Chandler Bing',
    'Monica Geller',
    'Joey Tribbiani',
    'Phoebe Buffay',
    'Emily Waltham',
    'Richard Burke',
    'Carol Willick',
    'Ben Geller',
    'Peter Becker',
    'Judy Geller',
    'Barry Farber',
    'Jack Geller',
    'Kate Miller',
})
56

57
# Special labels in the data mapped to the ids of catch-all group entities.
ENTITY_GROUPS = {
    '#GENERAL#': 'general',
    '#OTHER#': 'other',
}
58

59
# Gender code per entity id: 'm', 'f', or 'u' (presumably "unknown"; used for
# the group entities).
GENDER = dict(
    ross='m',
    rachel='f',
    chandler='m',
    monica='f',
    joey='m',
    phoebe='f',
    emily='f',
    richard='m',
    carol='f',
    ben='m',
    peter='m',
    judy='f',
    barry='m',
    jack='m',
    kate='f',
    general='u',
    other='u',
)
78

79

80
class Turn:
  """One speaker turn of a Contrack conversation.

  Attributes:
    speaker: Entity id of the speaker.
    msg: Text of the turn (read_season() passes the tokens joined by spaces).
    enrefs: List of entity-reference strings for the turn.
  """

  def __init__(self, speaker, msg, enrefs):
    self.speaker, self.msg, self.enrefs = speaker, msg, enrefs
87

88

89
class Conversation:
  """An ordered list of turns sharing a single conversation id.

  Attributes:
    conversation_id: Identifier of the conversation (read_season() passes the
      scene id).
    turns: Turns appended so far, in insertion order.
  """

  def __init__(self, conversation_id):
    self.conversation_id = conversation_id
    self.turns = list()

  def append(self, turn):
    """Adds `turn` at the end of the conversation."""
    self.turns += [turn]
98

99

100
# Every raw speaker label seen by read_season(); logged at the end of
# convert_data().
all_speakers: set = set()
101

102

103
def str_to_id(name):
  """Maps a raw character label from the json data to a Contrack entity id.

  Args:
    name: Speaker or mention label as it appears in the data.

  Returns:
    The lower-cased first name for members of KNOWN_CHARACTERS, the mapped id
    for labels in ENTITY_GROUPS, and 'other' for anything else.
  """
  if name in KNOWN_CHARACTERS:
    first_name, _, _ = name.partition(' ')
    return first_name.lower()
  return ENTITY_GROUPS.get(name, 'other')
110

111

112
def build_env_turn(conversation_id):
  """Creates the first turn of a conversation, declaring every entity.

  The utterance of the turn is the ids of all KNOWN_CHARACTERS followed by the
  ENTITY_GROUPS ids. Each id is declared as an entity reference spanning its
  own token position.

  Args:
    conversation_id: Id written into the `conv:` header of the turn.

  Returns:
    The environment turn as a single newline-terminated line.
  """
  # KNOWN_CHARACTERS is a frozenset of strings. Its iteration order depends on
  # string hashing, which is randomized per process, so sort the ids. This
  # makes the entity order (and the declared token indices) reproducible.
  entities = sorted(n.split(' ')[0].lower() for n in KNOWN_CHARACTERS)
  entities += ENTITY_GROUPS.values()

  enrefs_str = ''.join(
      f'[{name} {GENDER[name]} {i}-{i}]' for i, name in enumerate(entities))
  utterance_str = ' '.join(entities)

  return f'conv:{conversation_id}| {utterance_str}| {enrefs_str}\n'
125

126

127
def _parse_enrefs(utterance_id, tokens, entities_json):
  """Converts one sentence's character-entity annotations to enref strings.

  Args:
    utterance_id: Id of the utterance; used only in log messages.
    tokens: Tokens of the sentence the annotations refer to.
    entities_json: List of annotations, each expected to be
      [start, end, label, ...]. Shorter entries are logged and skipped.

  Returns:
    A list of enref strings, one per parseable annotation.
  """
  enrefs = []
  for entity_json in entities_json:
    if len(entity_json) < 3:
      logging.info('%s: cannot parse entity: %s', utterance_id, entity_json)
      continue
    start_index = entity_json[0]
    # Convert the (presumably exclusive) end offset to an inclusive one.
    end_index = entity_json[1] - 1
    entities = {str_to_id(e) for e in entity_json[2:]}
    if not entities:
      logging.info('empty entities list in %s', utterance_id)
      continue
    if len(entities) == 1:
      entity_name = next(iter(entities))
      gender = GENDER[entity_name]
      enrefs.append(f'[{entity_name} {gender} {start_index}-{end_index}]')
    else:
      # A mention of several entities becomes a group named after the
      # mention's first token. Members are sorted because the iteration order
      # of a set of strings varies between runs.
      enref_name = tokens[start_index]
      member_decl = ':'.join(sorted(entities))
      enrefs.append(
          f'[{enref_name} g({member_decl}) {start_index}-{end_index}]')
  return enrefs


def read_season(num):
  """Reads one season json file and converts it to Contrack conversations.

  Each scene becomes one Conversation. Each non-empty sentence of an utterance
  that has at least one speaker becomes one Turn. Every raw speaker label seen
  is added to the module-level `all_speakers` set.

  Args:
    num: Season number. The file read is `friends_season_<NN>.json` in
      FLAGS.input_dir, where NN is `num` zero-padded to two digits.

  Returns:
    A list of Conversation objects, one per scene, in file order.
  """
  filepath = os.path.join(FLAGS.input_dir, f'friends_season_{num:02d}.json')

  with tf.io.gfile.GFile(filepath, 'r') as json_file:
    season_json = json.load(json_file)

  conversations = []
  for episode_json in season_json['episodes']:
    for scene_json in episode_json['scenes']:
      conversation = Conversation(scene_json['scene_id'])

      for utterance_json in scene_json['utterances']:
        utterance_id = utterance_json['utterance_id']

        speakers = utterance_json['speakers']
        all_speakers.update(speakers)
        if not speakers:
          logging.info('no speaker for turn %s', utterance_id)
          continue
        if len(speakers) > 1:
          # Only the first listed speaker is kept.
          logging.info('multiple speakers for turn %s', utterance_id)
        speaker = str_to_id(speakers[0])

        for sentence_index, tokens in enumerate(utterance_json['tokens']):
          # A leading '_' appears to mark a placeholder sentence (assumption);
          # such sentences are skipped like empty ones.
          if not tokens or tokens[0] == '_':
            logging.info('empty turn %s', utterance_id)
            continue
          entities_json = utterance_json['character_entities'][sentence_index]
          enrefs = _parse_enrefs(utterance_id, tokens, entities_json)
          conversation.append(Turn(speaker, ' '.join(tokens), enrefs))
      conversations.append(conversation)
  return conversations
190

191

192
def convert_data():
  """Converts seasons 1-4 and writes Contrack training and test shards.

  For each season, scenes from episodes 1-21 go to the training shard
  `char_ident_trg.txt-<i>-of-00004`. Scenes from later episodes go to the test
  shard `char_ident_tst.txt-<i>-of-00004`. Both shards are in
  FLAGS.output_dir, and <i> is the zero-based season index. Each conversation
  is written as an environment turn, one line per turn, then a blank line.
  Finally, all speaker labels collected while reading are logged.

  Raises:
    ValueError: If a scene id does not match the expected `s.._e.._c..` form.
  """
  seasons = (1, 2, 3, 4)
  num_shards = len(seasons)
  # Scene ids presumably look like 's01_e02_c03'; group 1 is the episode.
  scene_id_re = re.compile(r's.._e(..)_c..')

  for season_num in seasons:
    conversations = read_season(season_num)
    shard = f'{season_num - 1:05d}-of-{num_shards:05d}'
    trg_path = os.path.join(FLAGS.output_dir, f'char_ident_trg.txt-{shard}')
    tst_path = os.path.join(FLAGS.output_dir, f'char_ident_tst.txt-{shard}')

    with tf.io.gfile.GFile(trg_path, 'w') as trg_file:
      with tf.io.gfile.GFile(tst_path, 'w') as tst_file:
        for conversation in conversations:
          c_id = conversation.conversation_id
          match = scene_id_re.search(c_id)
          if match is None:
            raise ValueError(f'Unexpected scene id format: {c_id!r}')
          episode_nr = int(match.group(1))
          # Episodes 1-21 of every season are training data, the rest test.
          out_file = trg_file if episode_nr <= 21 else tst_file

          out_file.write(build_env_turn(c_id))
          for turn in conversation.turns:
            enrefs_str = ''.join(turn.enrefs)
            out_file.write(f'{turn.speaker}| {turn.msg}| {enrefs_str}\n')
          out_file.write('\n')

  # Sorted so the log is stable across runs (set order is not).
  for speaker in sorted(all_speakers, key=str):
    logging.info(speaker)
220

221

222
def main(argv):
  """Entry point: converts the character mining data using the flag values.

  Args:
    argv: Positional command-line arguments left after absl flag parsing.
      Only the program name is expected.

  Raises:
    app.UsageError: If unexpected positional arguments are given, e.g. a flag
      that was mistyped without its leading dashes.
  """
  if len(argv) > 1:
    raise app.UsageError('Too many command-line arguments.')
  convert_data()


if __name__ == '__main__':
  app.run(main)
228

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.