def load_data(filepath):
    """Loads data from a JSON file."""
    with open(filepath, 'r') as f:
        data = json.load(f)
    return data

def tokenize_text(text):
    """Tokenizes the input text into a list of tokens."""
    # Words (including hyphen/underscore compounds) are kept whole; any other
    # non-whitespace character becomes its own token.
    return re.findall(r'\w+(?:[-_]\w+)*|\S', text)

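# An input entry is expected to follow the conversation format that the parsing
# below relies on (illustrative sketch, not copied from the dataset):
#   {"conversations": [
#       {"from": "human", "value": "Text: Barack Obama visited Paris."},
#       {"from": "human", "value": "What describes person in the text?"},
#       {"from": "gpt", "value": "['Barack Obama']"},
#       {"from": "human", "value": "What describes organization in the text?"},
#       {"from": "gpt", "value": "[]"}]}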
def extract_entity_spans(entry):
    """Extracts entity spans from an entry."""
    len_start = len("What describes ")
    len_end = len(" in the text?")
    entity_types, entity_texts, negative = [], [], []

    for c in entry['conversations']:
        if c['from'] == 'human' and c['value'].startswith('Text: '):
            text = c['value'][len('Text: '):]
            tokenized_text = tokenize_text(text)
        elif c['from'] == 'human' and c['value'].startswith('What describes '):
            # Strip the question template to recover the entity type.
            entity_type = c['value'][len_start:-len_end]
            entity_types.append(entity_type)
        elif c['from'] == 'gpt' and c['value'].startswith('['):
            if c['value'] == '[]':
                # No mentions for this type: record it as a negative type.
                negative.append(entity_types.pop())
                continue
            texts_ents = ast.literal_eval(c['value'])
            entity_texts.extend(texts_ents)
            # Repeat the current type so mention texts and types stay aligned.
            num_repeat = len(texts_ents) - 1
            entity_types.extend([entity_types[-1]] * num_repeat)

    entity_spans = []
    for j, entity_text in enumerate(entity_texts):
        entity_tokens = tokenize_text(entity_text)
        matches = []
        # Slide over the tokenized text to find case-insensitive token matches.
        for i in range(len(tokenized_text) - len(entity_tokens) + 1):
            if " ".join(tokenized_text[i:i + len(entity_tokens)]).lower() == " ".join(entity_tokens).lower():
                matches.append((i, i + len(entity_tokens) - 1, entity_types[j]))
        if matches:
            entity_spans.extend(matches)

    return {"tokenized_text": tokenized_text, "ner": entity_spans, "negative": negative}

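# For the illustrative entry sketched above, the returned record would look like:
#   {"tokenized_text": ["Barack", "Obama", "visited", "Paris", "."],
#    "ner": [(0, 1, "person")],
#    "negative": ["organization"]}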
def process_data(data):
    """Processes a list of data entries to extract entity spans."""
    all_data = [extract_entity_spans(entry) for entry in tqdm(data)]
    return all_data

def save_data_to_file(data, filepath):
    """Saves the processed data to a JSON file."""
    with open(filepath, 'w') as f:
        json.dump(data, f)

if __name__ == "__main__":
    # 'train.json' is assumed to be the raw Pile-NER conversations file; adjust the path as needed.
    path_pile_ner = 'train.json'
    data = load_data(path_pile_ner)
    processed_data = process_data(data)
    save_data_to_file(processed_data, 'pilener_train.json')

    print("dataset size:", len(processed_data))