2
from semantic_parsing_with_constrained_lm.src.semantic_parsing_with_constrained_lm.lm_openai_gpt3 import (
11
from DPR.dpr.utils.tasks import task_map
12
from src.utils.metric import compute_scores
15
logger = logging.getLogger(__name__)
18
# Hydra entry point: loads configs/inference_openai.yaml into `cfg`.
# NOTE(review): the decorated `def main(cfg):` header (and any intervening
# lines) is missing from this extraction; the statements below are the start
# of that function's body. The bare numeric lines ("21", "22", ...) are
# original-file line numbers leaked in by the extraction, not code.
@hydra.main(config_path="configs", config_name="inference_openai")
21
# OpenAI client authenticated from the OPENAI_TOKEN environment variable;
# raises KeyError at startup if the variable is unset.
client = GPT3Client(api_key=os.environ["OPENAI_TOKEN"])
22
# Incremental LM wrapper around the client; `engine` selects the OpenAI
# model and `cache_dir` caches responses on disk.
lm = IncrementalOpenAIGPT3(
23
client=client, engine=cfg.engine, cache_dir=cfg.cache_dir
26
# Run free-form text generation for a batch of entries and store each
# completion back onto its entry under "pred".
# NOTE(review): extraction gaps — the `args = {...}` dict head (original
# lines ~31-38) and the binding of the awaited result (presumably
# `results = await ...`) are missing; L28 below reads `results`, so a
# missing line must have assigned it. Bare numeric lines are extraction
# residue, not code.
async def get_pred_completion(entry_list, cfg):
30
# One prompt per entry; only the first encoded text variant is used
# for completion-style tasks.
prompt = [x["enc_text"][0].strip() for x in entry_list]
33
# Cap generation length so prompt + completion fits the model context
# (cfg.n_tokens presumably reserves space for the prompt — TODO confirm).
"max_tokens": min(cfg.generate_max_len, cfg.max_length - cfg.n_tokens),
39
# Rate-limited completions call against the configured engine.
await client.completions_rate_limited(cfg.engine, args)
41
# Write each returned choice's text back onto the corresponding entry,
# relying on the API preserving input order.
for i, x in enumerate(entry_list):
42
x["pred"] = results["choices"][i]["text"]
45
# Score a multiple-choice entry: compute a length-normalized negative
# log-likelihood for every candidate answer and pick the argmin.
# NOTE(review): extraction gaps — missing lines presumably include
# `entry = entry_list[0]`, the `res_list = []` accumulator with its
# `append`, and `nll = -summed_logprob` (L48 reads `nll`). Bare numeric
# lines are extraction residue, not code.
async def get_pred_choice(entry_list):
49
# Choice scoring only supports batch size 1 (one entry, many options).
assert len(entry_list) == 1
52
# One (prompt, answer) pair per candidate option.
for i in range(len(entry["enc_text"])):
53
enc_text = entry["enc_text"][i].strip()
54
enc_answer = entry["enc_answer"][i].strip()
56
# Tokenize prompt and candidate answer separately so only the
# answer tokens are scored.
prefix_tokens = lm.tokenizer.encode(enc_text)
57
tokenized_labels = lm.tokenizer.encode(enc_answer)
59
# Summed log-probability of the answer continuation given the prompt.
summed_logprob = await lm.logprob_of_completion(
60
prefix_tokens, tokenized_labels
63
# Length-normalize so longer answers aren't penalized.
loss = nll / len(tokenized_labels)
65
# Normalize per-option losses to sum to 1, then predict the index
# of the smallest (most likely) option.
sum_loss = sum(res_list)
66
normed_loss = [loss / sum_loss for loss in res_list]
67
entry["pred"] = normed_loss.index(min(normed_loss))
70
# Dispatch prediction tasks over the dataset in batches and await them all.
# NOTE(review): extraction gaps — missing lines presumably include
# `task_list = []` before the loop, an `else:` between L64 and L66, and
# the head of the gather expression that L70 is the tail of (likely
# `return [await f` over `asyncio.as_completed`). Bare numeric lines are
# extraction residue, not code.
async def run(data_list):
72
# Chunk the data into cfg.batch_size batches; each batch becomes one task.
for i, prompt in enumerate(more_itertools.chunked(data_list, cfg.batch_size)):
73
# Multiple encoded texts per entry signals a multiple-choice task,
# which the choice path only supports at batch size 1.
if len(data_list[0]["enc_text"]) > 1:
74
assert cfg.batch_size == 1
75
task = asyncio.create_task(get_pred_choice(prompt))
77
# (Completion path — an `else:` line is missing above this.)
task = asyncio.create_task(get_pred_completion(prompt, cfg))
78
task_list.append(task)
81
# Tail of the awaiting expression: iterate tasks as they finish, with a
# progress bar sized to the number of dispatched tasks.
for f in tqdm.tqdm(asyncio.as_completed(task_list), total=len(task_list))
86
# Tail of the entry-point body: load prompts, run inference, persist
# predictions, compute and log metrics, and append a one-line summary.
# NOTE(review): extraction gaps — the body of the output-file `with` (L84,
# presumably `json.dump(res, f)`) and the `f.write(` head before the
# f-string on L98 are missing. Bare numeric lines are extraction residue,
# not code.
with open(cfg.prompt_file) as f:
87
data_list = json.load(f)
88
# Drive the async batch runner to completion on a fresh event loop.
res = asyncio.run(run(data_list))
89
# Flatten one level: run() returns per-batch lists of entries.
res = list(more_itertools.collapse(res, levels=1))
90
# Ensure both output directories exist before writing.
os.makedirs(os.path.dirname(cfg.output_file), exist_ok=True)
91
os.makedirs(os.path.dirname(cfg.res_file), exist_ok=True)
93
with open(cfg.output_file, "w") as f:
95
# Look up the task class by name and score predictions with its metric.
task = task_map.cls_dic[cfg.task_name]()
96
scores = compute_scores(task.metric, res)
97
# n_prompts > 0 means retrieval-augmented prompting (UPRISE); otherwise
# this was a zero-shot run.
method = "UPRISE" if int(res[0]["n_prompts"]) > 0 else "0-SHOT"
98
logger.info("method: %s", method)
99
logger.info("scores: %s", str(scores))
100
# Append (not overwrite) a summary line so repeated runs accumulate.
with open(cfg.res_file, "a") as f:
102
f"LLM: {str(cfg.engine)}; task_name: {str(cfg.task_name)}; Method: {method}; scores: {str(scores)}\n"
107
if __name__ == "__main__":