quanto
107 lines · 3.6 KB
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
14
import gc
import time

import numpy as np
import torch
from tqdm.auto import tqdm
from transformers import GenerationConfig
23
def latency(model, tokenizer, device, batch_size=1, prompt_length=512, nb_tokens=512, iterations=10):
    """Benchmark the mean per-token generation latency of *model* on *device*.

    Runs ``iterations`` greedy-decoding passes, each generating exactly
    ``nb_tokens`` new tokens from a random prompt of ``prompt_length`` tokens,
    and prints/returns the average latency per generated token.

    Args:
        model: causal-LM exposing ``generate`` and ``config.vocab_size``
            (assumed transformers-compatible — confirm against caller).
        tokenizer: tokenizer providing ``pad_token_id`` for generation.
        device: ``torch.device`` to benchmark on (``cuda``/``mps``/``cpu``).
        batch_size: number of prompts per timed ``generate`` call.
        prompt_length: number of random input tokens per prompt.
        nb_tokens: exact number of new tokens to generate (min == max).
        iterations: number of timed ``generate`` calls to average over.

    Returns:
        Mean latency per generated token, in milliseconds (float).
    """

    def synchronize(device):
        # Drain any queued device work so timings bracket only generate().
        if device.type == "cuda":
            torch.cuda.synchronize()
        elif device.type == "mps":
            torch.mps.synchronize()
        else:
            torch.cpu.synchronize()

    def timing_event(device):
        # cuda/mps expose native timing events; fall back to wall-clock on CPU.
        if device.type == "cuda":
            return torch.cuda.Event(enable_timing=True)
        elif device.type == "mps":
            return torch.mps.Event(enable_timing=True)

        class CPUEvent:
            def __init__(self):
                self.time = None

            def record(self):
                self.time = time.time()

            def elapsed_time(self, other):
                # Explicit check instead of `assert`: asserts are stripped under
                # `python -O`, which would silently compute with None here.
                if self.time is None or other.time is None:
                    raise RuntimeError("elapsed_time() called before both events were recorded")
                return (other.time - self.time) * 1000

        return CPUEvent()

    generation_config = GenerationConfig(
        max_new_tokens=nb_tokens,
        min_new_tokens=nb_tokens,
        use_cache=True,
        pad_token_id=tokenizer.pad_token_id,
        num_beams=1,
        do_sample=False,
        eos_token_id=None,  # This is required for min_new_tokens to actually have an effect.
    )
    if getattr(model, "generation_config", None) is not None:
        # greedy_search falls back on this eos_token_id that we need to set to
        # None as well for min_new_tokens to have an effect.
        model.generation_config.eos_token_id = None

    synchronize(device)
    if device.type == "cuda":
        torch.cuda.reset_peak_memory_stats()

    memory = get_device_memory(device)
    if memory is not None:
        print(f"Device memory: {memory / (2 ** 30):.4f} GB")

    latencies = []
    # Random token ids drawn from [1, vocab_size - 2]: avoids id 0 and the last id.
    input_ids = torch.randint(1, model.config.vocab_size - 1, size=(batch_size, prompt_length)).to(device)
    masks = torch.ones(batch_size, prompt_length, dtype=torch.int32).to(device)

    for _ in tqdm(range(iterations)):
        start_event = timing_event(device)
        end_event = timing_event(device)
        synchronize(device)
        start_event.record()

        _ = model.generate(input_ids, attention_mask=masks, generation_config=generation_config)
        end_event.record()
        synchronize(device)

        latency_ms = start_event.elapsed_time(end_event)
        latencies.append(latency_ms)

    if device.type == "cuda":
        peak_memory = torch.cuda.max_memory_allocated()
        print(f"Peak memory during benchmark: {peak_memory / (2 ** 30):.4f} GB")

    # Normalize by the fixed number of generated tokens to get ms per token.
    mean_latency = np.mean(latencies) / generation_config.min_new_tokens
    print(f"Average latency per token: {mean_latency} ms")
    return mean_latency
97
98
def get_device_memory(device):
    """Return the bytes currently allocated on *device*, or ``None`` when the
    device type exposes no allocation counter (e.g. plain CPU).

    Runs a garbage-collection pass first, and empties the device cache before
    reading the counter, so the figure reflects live allocations only.
    """
    gc.collect()
    device_type = device.type
    if device_type == "cuda":
        torch.cuda.empty_cache()
        return torch.cuda.memory_allocated()
    if device_type == "mps":
        torch.mps.empty_cache()
        return torch.mps.current_allocated_memory()
    return None
108