lmops

Форк
0
/
inference_openai.py 
108 строк · 3.7 Кб
1
import asyncio
2
from semantic_parsing_with_constrained_lm.src.semantic_parsing_with_constrained_lm.lm_openai_gpt3 import (
3
    GPT3Client,
4
    IncrementalOpenAIGPT3,
5
)
6
import json
7
import more_itertools
8
import tqdm
9
import hydra
10
import os
11
from DPR.dpr.utils.tasks import task_map
12
from src.utils.metric import compute_scores
13
import logging
14

15
logger = logging.getLogger(__name__)
16

17

18
@hydra.main(config_path="configs", config_name="inference_openai")
19
def main(cfg):
20
    print(cfg)
21
    client = GPT3Client(api_key=os.environ["OPENAI_TOKEN"])
22
    lm = IncrementalOpenAIGPT3(
23
        client=client, engine=cfg.engine, cache_dir=cfg.cache_dir
24
    )
25

26
    async def get_pred_completion(entry_list, cfg):
27
        """
28
        for text completion
29
        """
30
        prompt = [x["enc_text"][0].strip() for x in entry_list]
31
        args = {
32
            "prompt": prompt,
33
            "max_tokens": min(cfg.generate_max_len, cfg.max_length - cfg.n_tokens),
34
            "stop": ["\n"],
35
            "echo": False,
36
            "logprobs": 1,
37
        }
38
        results = (
39
            await client.completions_rate_limited(cfg.engine, args)  # type: ignore
40
        ).json()
41
        for i, x in enumerate(entry_list):
42
            x["pred"] = results["choices"][i]["text"]
43
        return entry_list
44

45
    async def get_pred_choice(entry_list):
46
        """
47
        for multiple choice
48
        """
49
        assert len(entry_list) == 1  # because bsz=1
50
        entry = entry_list[0]
51
        res_list = []
52
        for i in range(len(entry["enc_text"])):
53
            enc_text = entry["enc_text"][i].strip()
54
            enc_answer = entry["enc_answer"][i].strip()
55

56
            prefix_tokens = lm.tokenizer.encode(enc_text)
57
            tokenized_labels = lm.tokenizer.encode(enc_answer)
58

59
            summed_logprob = await lm.logprob_of_completion(
60
                prefix_tokens, tokenized_labels
61
            )  # likelihood
62
            nll = -summed_logprob  # negative likelihood
63
            loss = nll / len(tokenized_labels)  # average by length of tokenized_labels
64
            res_list.append(loss)
65
        sum_loss = sum(res_list)
66
        normed_loss = [loss / sum_loss for loss in res_list]
67
        entry["pred"] = normed_loss.index(min(normed_loss))
68
        return [entry]
69

70
    async def run(data_list):
71
        task_list = []
72
        for i, prompt in enumerate(more_itertools.chunked(data_list, cfg.batch_size)):
73
            if len(data_list[0]["enc_text"]) > 1:  # multiple choice
74
                assert cfg.batch_size == 1
75
                task = asyncio.create_task(get_pred_choice(prompt))
76
            else:  # text completion
77
                task = asyncio.create_task(get_pred_completion(prompt, cfg))
78
            task_list.append(task)
79
        responses = [
80
            await f
81
            for f in tqdm.tqdm(asyncio.as_completed(task_list), total=len(task_list))
82
        ]
83
        return responses
84

85
    def run_main(cfg):
86
        with open(cfg.prompt_file) as f:
87
            data_list = json.load(f)
88
        res = asyncio.run(run(data_list))
89
        res = list(more_itertools.collapse(res, levels=1))
90
        os.makedirs(os.path.dirname(cfg.output_file), exist_ok=True)
91
        os.makedirs(os.path.dirname(cfg.res_file), exist_ok=True)
92

93
        with open(cfg.output_file, "w") as f:
94
            json.dump(res, f)
95
        task = task_map.cls_dic[cfg.task_name]()
96
        scores = compute_scores(task.metric, res)
97
        method = "UPRISE" if int(res[0]["n_prompts"]) > 0 else "0-SHOT"
98
        logger.info("method: %s", method)
99
        logger.info("scores: %s", str(scores))
100
        with open(cfg.res_file, "a") as f:
101
            f.write(
102
                f"LLM: {str(cfg.engine)}; task_name: {str(cfg.task_name)}; Method: {method}; scores: {str(scores)}\n"
103
            )
104
    run_main(cfg)
105

106

107
if __name__ == "__main__":
108
    main()
109

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.