# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import gc
import time

import numpy as np
import torch
from tqdm.auto import tqdm
from transformers import GenerationConfig


def latency(model, tokenizer, device, batch_size=1, prompt_length=512, nb_tokens=512, iterations=10):
    # Wait for all pending work on the device so that timings are not skewed by queued kernels.
    def synchronize(device):
        if device.type == "cuda":
            torch.cuda.synchronize()
        elif device.type == "mps":
            torch.mps.synchronize()
        else:
            torch.cpu.synchronize()

    # Backend-specific timing events; fall back to a simple wall-clock event on CPU.
    def timing_event(device):
        if device.type == "cuda":
            return torch.cuda.Event(enable_timing=True)
        elif device.type == "mps":
            return torch.mps.Event(enable_timing=True)

        class CPUEvent:
            def __init__(self):
                self.time = None

            def record(self):
                self.time = time.time()

            def elapsed_time(self, other):
                assert self.time is not None
                assert other.time is not None
                return (other.time - self.time) * 1000

        return CPUEvent()

    generation_config = GenerationConfig(
        max_new_tokens=nb_tokens,
        min_new_tokens=nb_tokens,
        use_cache=True,
        pad_token_id=tokenizer.pad_token_id,
        num_beams=1,
        do_sample=False,
        eos_token_id=None,  # This is required for min_new_tokens to actually have an effect.
    )
    if getattr(model, "generation_config", None) is not None:
        # greedy_search falls back on the model's own eos_token_id, which must also be set to None
        # for min_new_tokens to have an effect.
        model.generation_config.eos_token_id = None

    synchronize(device)
    if device.type == "cuda":
        torch.cuda.reset_peak_memory_stats()

    memory = get_device_memory(device)
    if memory is not None:
        print(f"Device memory: {memory / (2 ** 30):.4f} GB")

    latencies = []
    # Random token ids stand in for a real prompt of the requested length.
    input_ids = torch.randint(1, model.config.vocab_size - 1, size=(batch_size, prompt_length)).to(device)
    masks = torch.ones(batch_size, prompt_length, dtype=torch.int32).to(device)

    for _ in tqdm(range(iterations)):
        start_event = timing_event(device)
        end_event = timing_event(device)
        synchronize(device)
        start_event.record()

        _ = model.generate(input_ids, attention_mask=masks, generation_config=generation_config)
        end_event.record()
        synchronize(device)

        latency_ms = start_event.elapsed_time(end_event)
        latencies.append(latency_ms)

    if device.type == "cuda":
        peak_memory = torch.cuda.max_memory_allocated()
        print(f"Peak memory during benchmark: {peak_memory / (2 ** 30):.4f} GB")

    mean_latency = np.mean(latencies) / generation_config.min_new_tokens
    print(f"Average latency per token: {mean_latency} ms")
    return mean_latency


def get_device_memory(device):
    # Memory currently allocated on the device, in bytes (None for plain CPU).
    gc.collect()
    if device.type == "cuda":
        torch.cuda.empty_cache()
        return torch.cuda.memory_allocated()
    elif device.type == "mps":
        torch.mps.empty_cache()
        return torch.mps.current_allocated_memory()
    return None
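

# Usage sketch (not part of the original file): a hedged example of how `latency`
# might be invoked. The model id below is an illustrative assumption; any causal LM
# and tokenizer pair from `transformers` with a defined pad token should work.
if __name__ == "__main__":
    from transformers import AutoModelForCausalLM, AutoTokenizer

    model_id = "facebook/opt-350m"  # hypothetical choice, not mandated by this benchmark
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    tokenizer = AutoTokenizer.from_pretrained(model_id)
    if tokenizer.pad_token_id is None:
        # The GenerationConfig above requires a pad_token_id; reuse EOS when none is defined.
        tokenizer.pad_token_id = tokenizer.eos_token_id

    model = AutoModelForCausalLM.from_pretrained(model_id).to(device)
    model.eval()

    latency(model, tokenizer, device, batch_size=1, prompt_length=128, nb_tokens=128, iterations=3)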
