google-research
227 lines · 6.4 KB
1# coding=utf-8
2# Copyright 2024 The Google Research Authors.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8# http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15
16"""Converts character mining json data to Contrack text format."""
17
18import json19import os20import re21from typing import List, Sequence22
23from absl import app24from absl import flags25from absl import logging26
27import tensorflow as tf28
flags.DEFINE_string(
    'input_dir', '',
    'Path to the directory containing the input data json files.')

flags.DEFINE_string(
    'output_dir', '',
    'Path to the directory where the conversations are written.')

FLAGS = flags.FLAGS

# Characters from the "Friends" character-mining dataset that are tracked as
# named entities; full names exactly as they appear in the source json.
KNOWN_CHARACTERS = frozenset([
    'Ross Geller',
    'Rachel Green',
    'Chandler Bing',
    'Monica Geller',
    'Joey Tribbiani',
    'Phoebe Buffay',
    'Emily Waltham',
    'Richard Burke',
    'Carol Willick',
    'Ben Geller',
    'Peter Becker',
    'Judy Geller',
    'Barry Farber',
    'Jack Geller',
    'Kate Miller'
])

# Maps the dataset's special entity markers to Contrack entity-group ids.
ENTITY_GROUPS = {'#GENERAL#': 'general', '#OTHER#': 'other'}

# Gender tag for each entity id ('m'ale / 'f'emale / 'u'nknown), keyed by the
# lowercased first name produced by str_to_id.
GENDER = {
    'ross': 'm',
    'rachel': 'f',
    'chandler': 'm',
    'monica': 'f',
    'joey': 'm',
    'phoebe': 'f',
    'emily': 'f',
    'richard': 'm',
    'carol': 'f',
    'ben': 'm',
    'peter': 'm',
    'judy': 'f',
    'barry': 'm',
    'jack': 'm',
    'kate': 'f',
    'general': 'u',
    'other': 'u'
}
78
79
class Turn(object):
  """A single utterance in a Contrack conversation.

  Attributes:
    speaker: Entity id of the speaker (e.g. 'ross').
    msg: The utterance text, tokens joined by single spaces.
    enrefs: List of enref declaration strings for this utterance.
  """

  def __init__(self, speaker, msg, enrefs):
    self.speaker, self.msg, self.enrefs = speaker, msg, enrefs
88
class Conversation(object):
  """An ordered sequence of turns belonging to one scene.

  Attributes:
    conversation_id: The scene id from the source json.
    turns: The Turn objects appended so far, in order.
  """

  def __init__(self, conversation_id):
    self.turns = []
    self.conversation_id = conversation_id

  def append(self, turn):
    """Appends a turn to the end of the conversation."""
    self.turns.append(turn)
99
# Accumulates every raw speaker name seen across all parsed seasons; the full
# set is logged at the end of convert_data for inspection.
all_speakers = set()
102
def str_to_id(name):
  """Converts a raw character name into a canonical Contrack entity id.

  Known characters map to their lowercased first name, the special group
  markers map via ENTITY_GROUPS, and everything else maps to 'other'.

  Args:
    name: A speaker/entity name string from the source json.

  Returns:
    The canonical entity id string.
  """
  if name in KNOWN_CHARACTERS:
    first_name = name.split(' ')[0]
    return first_name.lower()
  return ENTITY_GROUPS.get(name, 'other')
111
def build_env_turn(conversation_id):
  """Creates a first turn declaring all known entities for a conversation.

  The returned line lists every entity id as the utterance text, with one
  enref declaration per entity so that the entity inventory is established
  before the real turns.

  Args:
    conversation_id: Scene id string, e.g. 's01_e05_c03'.

  Returns:
    A Contrack-format line 'conv:<id>| <entities>| <enrefs>\n'.
  """
  # Sort the names: frozenset iteration order depends on Python's string hash
  # seed, so without sorting the entity order (and hence each entity's token
  # index) would differ from run to run.
  entities = [n.split(' ')[0].lower() for n in sorted(KNOWN_CHARACTERS)]
  entities += ENTITY_GROUPS.values()

  enrefs_decls = [
      f'[{name} {GENDER[name]} {i}-{i}]' for i, name in enumerate(entities)
  ]
  utterance_str = ' '.join(entities)
  enrefs_str = ''.join(enrefs_decls)

  turn = f'conv:{conversation_id}| {utterance_str}| {enrefs_str}\n'
  return turn
126
def read_season(num):
  """Reads one season's json file and returns it as Contrack conversations.

  Args:
    num: Season number; the file read is friends_season_0<num>.json under
      FLAGS.input_dir.

  Returns:
    A list of Conversation objects, one per scene.
  """
  global all_speakers
  filepath = os.path.join(FLAGS.input_dir, f'friends_season_0{num}.json')

  with tf.io.gfile.GFile(filepath, 'r') as json_file:
    season_json = json.load(json_file)

  conversations = []
  for episode_json in season_json['episodes']:
    for scene_json in episode_json['scenes']:
      conversation = Conversation(scene_json['scene_id'])

      for utterance_json in scene_json['utterances']:
        utterance_id = utterance_json['utterance_id']

        # Extract speaker. Turns without a speaker are dropped; when several
        # speakers are listed only the first is kept.
        speakers = utterance_json['speakers']
        all_speakers.update(speakers)
        if not speakers:
          logging.info('no speaker for turn %s', utterance_id)
          continue
        if len(speakers) > 1:
          logging.info('multiple speakers for turn %s', utterance_id)
        speaker = str_to_id(speakers[0])

        for sentence_index, tokens in enumerate(utterance_json['tokens']):
          # Skip empty sentences (the dataset marks them with '_').
          if not tokens or tokens[0] == '_':
            logging.info('empty turn %s', utterance_id)
            continue

          # Extract enrefs. Each annotation looks like
          # [start, end, name, ...]; end is exclusive in the json but
          # inclusive in the Contrack format, hence the -1.
          entities_json = utterance_json['character_entities'][sentence_index]
          enrefs = []
          for entity_json in entities_json:
            if len(entity_json) < 3:
              logging.info('%s: cannot parse entity: %s', utterance_id,
                           entity_json)
              continue
            start_index = entity_json[0]
            end_index = entity_json[1] - 1
            entities = {str_to_id(e) for e in entity_json[2:]}
            if not entities:
              logging.info('empty entities list in %s', utterance_id)
              continue
            if len(entities) == 1:
              entity_name = next(iter(entities))
              gender = GENDER[entity_name]
              enref = f'[{entity_name} {gender} {start_index}-{end_index}]'
            else:
              # Multiple entities form a group named after its first token.
              # Sort the members: set iteration order is nondeterministic
              # across runs, and the output should be reproducible.
              enref_name = tokens[start_index]
              member_decl = ':'.join(sorted(entities))
              enref = f'[{enref_name} g({member_decl}) {start_index}-{end_index}]'
            enrefs.append(enref)

          conversation.append(Turn(speaker, ' '.join(tokens), enrefs))
      conversations.append(conversation)
  return conversations
191
def convert_data():
  """Converts seasons 1-4 and writes sharded train/test text files.

  For each season, episodes 1-21 go to the training shard and the remaining
  episodes to the test shard. Each conversation starts with an environment
  turn declaring all known entities, followed by one line per turn and a
  blank separator line. Finally, every raw speaker name encountered is
  logged.
  """
  for season_num in [1, 2, 3, 4]:
    conversations = read_season(season_num)
    trg_path = os.path.join(FLAGS.output_dir,
                            f'char_ident_trg.txt-0000{season_num - 1}-of-00004')
    tst_path = os.path.join(FLAGS.output_dir,
                            f'char_ident_tst.txt-0000{season_num - 1}-of-00004')

    with tf.io.gfile.GFile(trg_path, 'w') as trg_file, \
        tf.io.gfile.GFile(tst_path, 'w') as tst_file:
      for conversation in conversations:
        c_id = conversation.conversation_id
        # Scene ids look like 's01_e05_c03'; the episode number decides the
        # train/test split.
        episode_nr = int(re.search(r's.._e(..)_c..', c_id).group(1))
        # Renamed from 'file' to avoid shadowing the builtin.
        out_file = trg_file if episode_nr <= 21 else tst_file

        out_file.write(build_env_turn(c_id))
        for turn in conversation.turns:
          line = f'{turn.speaker}| {turn.msg}| {"".join(turn.enrefs)}\n'
          out_file.write(line)
        out_file.write('\n')

  for s in all_speakers:
    logging.info(s)
221
def main(argv):
  """absl entry point; flags are parsed by app.run before this is called."""
  del argv  # Unused.
  convert_data()


if __name__ == '__main__':
  app.run(main)