GLiNER

process_pilener.py
import json
import re
import ast
from tqdm import tqdm

def load_data(filepath):
    """Loads data from a JSON file."""
    with open(filepath, 'r') as f:
        data = json.load(f)
    return data

def tokenize_text(text):
    """Tokenizes the input text into a list of tokens."""
    return re.findall(r'\w+(?:[-_]\w+)*|\S', text)
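
# A quick sketch of what the tokenizer produces (hypothetical input):
#   tokenize_text("John lives in New-York!")
#   -> ['John', 'lives', 'in', 'New-York', '!']
# Hyphen- and underscore-joined words stay single tokens; punctuation splits off.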

def extract_entity_spans(entry):
    """Extracts entity spans from an entry."""
    len_start = len("What describes ")
    len_end = len(" in the text?")
    entity_types, entity_texts, negative = [], [], []

    for c in entry['conversations']:
        if c['from'] == 'human' and c['value'].startswith('Text: '):
            # The passage to annotate.
            text = c['value'][len('Text: '):]
            tokenized_text = tokenize_text(text)
        elif c['from'] == 'human' and c['value'].startswith('What describes '):
            # A question asking for one entity type, e.g. "What describes person in the text?".
            entity_type = c['value'][len_start:-len_end]
            entity_types.append(entity_type)
        elif c['from'] == 'gpt' and c['value'].startswith('['):
            if c['value'] == '[]':
                # No mentions found: record the asked-for type as a negative example.
                negative.append(entity_types.pop())
                continue
            # The answer is a Python-style list literal of entity strings.
            texts_ents = ast.literal_eval(c['value'])
            entity_texts.extend(texts_ents)
            # Repeat the current type so entity_types stays aligned with entity_texts.
            num_repeat = len(texts_ents) - 1
            entity_types.extend([entity_types[-1]] * num_repeat)

    # Locate each entity string in the tokenized passage (case-insensitive) and
    # record its (start, end, type) token span; unmatched strings are dropped.
    entity_spans = []
    for j, entity_text in enumerate(entity_texts):
        entity_tokens = tokenize_text(entity_text)
        matches = []
        for i in range(len(tokenized_text) - len(entity_tokens) + 1):
            if " ".join(tokenized_text[i:i + len(entity_tokens)]).lower() == " ".join(entity_tokens).lower():
                matches.append((i, i + len(entity_tokens) - 1, entity_types[j]))
        if matches:
            entity_spans.extend(matches)

    return {"tokenized_text": tokenized_text, "ner": entity_spans, "negative": negative}

def process_data(data):
    """Processes a list of data entries to extract entity spans."""
    all_data = [extract_entity_spans(entry) for entry in tqdm(data)]
    return all_data

def save_data_to_file(data, filepath):
    """Saves the processed data to a JSON file."""
    with open(filepath, 'w') as f:
        json.dump(data, f)

if __name__ == "__main__":
    # Download the Pile-NER data (note: use /resolve/, not /blob/, so wget
    # fetches the raw JSON rather than the HTML viewer page):
    # "wget https://huggingface.co/datasets/Universal-NER/Pile-NER-type/resolve/main/train.json"
    path_pile_ner = 'train.json'
    data = load_data(path_pile_ner)
    processed_data = process_data(data)
    save_data_to_file(processed_data, 'pilener_train.json')

    print("dataset size:", len(processed_data))
