pytorch (fork)
per_sample_grads.py — 97 lines, 2.6 KB
1
import time
2

3
import torchvision.models as models
4
from opacus import PrivacyEngine
5
from opacus.utils.module_modification import convert_batchnorm_modules
6

7
import torch
8
import torch.nn as nn
9
from functorch import grad, make_functional, vmap
10

11

12
device = "cuda"
13
batch_size = 128
14
torch.manual_seed(0)
15

16
model_functorch = convert_batchnorm_modules(models.resnet18(num_classes=10))
17
model_functorch = model_functorch.to(device)
18
criterion = nn.CrossEntropyLoss()
19

20
images = torch.randn(batch_size, 3, 32, 32, device=device)
21
targets = torch.randint(0, 10, (batch_size,), device=device)
22
func_model, weights = make_functional(model_functorch)
23

24

25
def compute_loss(weights, image, target):
    """Loss for a single sample: run one (image, target) pair through the
    functional model as a batch of size 1 and return its cross-entropy loss.
    """
    batched_image = image.unsqueeze(0)
    batched_target = target.unsqueeze(0)
    return criterion(func_model(weights, batched_image), batched_target)
31

32

33
def functorch_per_sample_grad():
    """Compute per-sample gradients with functorch's vmap(grad(...)) and time it.

    Returns:
        (result, seconds): the per-sample gradients (one tensor per model
        parameter, each with a leading batch dimension) and the wall-clock
        duration of the computation in seconds.
    """
    compute_grad = grad(compute_loss)
    # Map over the batch dimension of images/targets; weights are shared
    # across samples (in_dims=None).
    compute_per_sample_grad = vmap(compute_grad, (None, 0, 0))

    # Drain any previously queued GPU work so the timer measures only this
    # computation — without this, earlier async kernels inflate the interval.
    torch.cuda.synchronize()
    start = time.time()
    result = compute_per_sample_grad(weights, images, targets)
    # CUDA launches are asynchronous; wait for completion before stopping.
    torch.cuda.synchronize()
    end = time.time()

    return result, end - start  # end - start in seconds
43

44

45
# Re-seed so the Opacus model is initialized with exactly the same parameters
# as the functorch model built above.
torch.manual_seed(0)
model_opacus = convert_batchnorm_modules(models.resnet18(num_classes=10))
model_opacus = model_opacus.to(device)
criterion = nn.CrossEntropyLoss()
for p_f, p_o in zip(model_functorch.parameters(), model_opacus.parameters()):
    assert torch.allclose(p_f, p_o)  # Sanity check

# NOTE(review): the engine is constructed but never attached to an optimizer;
# this appears to rely on the (pre-1.0) Opacus API installing grad_sample
# hooks on the model at construction — confirm against the pinned version.
privacy_engine = PrivacyEngine(
    model_opacus,
    sample_rate=0.01,
    alphas=[10, 100],
    noise_multiplier=1,
    max_grad_norm=10000.0,  # effectively disables clipping for the comparison
)
59

60

61
def opacus_per_sample_grad():
    """Compute per-sample gradients via Opacus' backward hooks and time it.

    Returns:
        (expected, seconds): a list with one per-sample-gradient tensor per
        model parameter (read from ``p.grad_sample``) and the wall-clock
        duration of the forward+backward pass in seconds.
    """
    # Drain any previously queued GPU work so the timer measures only this
    # computation (mirrors the functorch benchmark for a fair comparison).
    torch.cuda.synchronize()
    start = time.time()
    output = model_opacus(images)
    loss = criterion(output, targets)
    loss.backward()
    # CUDA is asynchronous; wait for the backward pass before stopping.
    torch.cuda.synchronize()
    end = time.time()
    expected = [p.grad_sample for p in model_opacus.parameters()]
    # Clear Opacus' per-parameter state so repeated calls don't accumulate.
    for p in model_opacus.parameters():
        delattr(p, "grad_sample")
        p.grad = None
    return expected, end - start
73

74

75
# Warm up the functorch path (kernel compilation, caching, CUDA context)
# before taking the measured run.
for _ in range(5):
    _, seconds = functorch_per_sample_grad()
    print(seconds)

# Measured functorch run.
result, seconds = functorch_per_sample_grad()
print(seconds)

# Warm up and then measure the Opacus path the same way.
for _ in range(5):
    _, seconds = opacus_per_sample_grad()
    print(seconds)

expected, seconds = opacus_per_sample_grad()
print(seconds)

# Detach so the per-sample grads can be inspected outside autograd.
result = [r.detach() for r in result]
print(len(result))

# TODO: The following shows that the per-sample-grads computed are different.
# This concerns me a little; we should compare to a source of truth.
# for i, (r, e) in enumerate(list(zip(result, expected))[::-1]):
#     if torch.allclose(r, e, rtol=1e-5):
#         continue
#     print(-(i+1), ((r - e)/(e + 0.000001)).abs().max())
98

Use of cookies

We use cookies in accordance with the Privacy Policy and the Cookie Policy.

By clicking "Accept", you give Sber Tech JSC consent to process your personal data in order to improve our website and the GitVerse service, and to make them more convenient to use.

You can disable cookies yourself in your browser settings.