20
from packaging import version
22
from quanto import absmax_scale, qint8, quantize_activation, quantize_weight
25
def torch_min_version(v):
    """Skip the decorated test unless the installed pytorch is at least *v*.

    Args:
        v: minimum required pytorch version, as a version string (e.g. "2.1.0").

    Returns:
        A decorator: when ``torch.__version__`` parses older than *v* the
        wrapped test is skipped via ``pytest.skip``; otherwise it runs and
        its result is returned.
    """

    def torch_min_version_decorator(test):
        @functools.wraps(test)
        def test_wrapper(*args, **kwargs):
            if version.parse(torch.__version__) < version.parse(v):
                pytest.skip(f"Requires pytorch >= {v}")
            # NOTE(review): the wrapper body was truncated in this view
            # (original lines 31-34 missing); as shown it never invoked the
            # wrapped test. Forwarding the call is the required completion.
            return test(*args, **kwargs)

        return test_wrapper

    return torch_min_version_decorator
41
# NOTE(review): fragment of a device-equality helper — the enclosing `def`
# header (original line 40) is outside this view; presumably something like
# `device_eq(a, b)` comparing two torch.device objects, and a preceding
# device-type check may also be hidden — confirm against the full file.
# A device with no explicit index (e.g. "cuda") is treated as index 0, so
# "cuda" and "cuda:0" compare equal.
a_index = a.index if a.index is not None else 0
b_index = b.index if b.index is not None else 0
return a_index == b_index
46
def random_tensor(shape, dtype=torch.float32):
    """Return a tensor of the given shape, uniformly sampled in [-1, 1)."""
    uniform = torch.rand(shape, dtype=dtype)
    # Stretch the unit-interval samples onto [-1, 1).
    return uniform * 2 - 1
51
def random_qactivation(shape, qtype=qint8, dtype=torch.float32):
    """Build a randomly-valued quantized activation tensor.

    The quantization scale is obtained with an absmax reduction over the
    random source tensor before quantizing it as an activation.
    """
    raw = random_tensor(shape, dtype)
    qscale = absmax_scale(raw, qtype=qtype)
    return quantize_activation(raw, qtype=qtype, scale=qscale)
57
def random_qweight(shape, qtype, dtype=torch.float32, axis=0, group_size=None):
    """Build a randomly-valued quantized weight tensor.

    Quantization happens along *axis*, optionally in groups of
    *group_size* elements.
    """
    raw = random_tensor(shape, dtype)
    return quantize_weight(raw, qtype=qtype, axis=axis, group_size=group_size)
62
def assert_similar(a, b, atol=None, rtol=None):
    """Verify that the cosine similarity of the two inputs is close to 1.0 everywhere.

    Args:
        a, b: tensors of identical dtype and shape.
        atol: absolute tolerance; defaults to the dtype's resolution.
        rtol: relative tolerance; defaults to a per-dtype heuristic
            (only float32/float16/bfloat16 have defaults).

    Raises:
        ValueError: when the cosine similarity deviates too much from 1.0.
    """
    assert a.dtype == b.dtype
    assert a.shape == b.shape
    # Fill in defaults only when the caller did not supply tolerances: the
    # visible code unconditionally overwrote atol/rtol, which made both
    # parameters dead (the `if ... is None` guards were dropped).
    if atol is None:
        atol = torch.finfo(a.dtype).resolution
    if rtol is None:
        rtol = {torch.float32: 1e-5, torch.float16: 1e-3, torch.bfloat16: 1e-1}[a.dtype]
    sim = torch.nn.functional.cosine_similarity(a.flatten(), b.flatten(), dim=0)
    if not torch.allclose(sim, torch.tensor(1.0, dtype=sim.dtype), atol=atol, rtol=rtol):
        max_deviation = torch.min(sim)
        raise ValueError(f"Alignment {max_deviation:.8f} deviates too much from 1.0 with atol={atol}, rtol={rtol}")
79
def get_device_memory(device):
    """Return the currently allocated memory, in bytes, on *device*.

    The backend cache is emptied first so the figure reflects live
    allocations only. Only ``cuda`` and ``mps`` device types are handled
    in the lines visible here; any other type falls through (returning
    None) unless more branches exist past this view — the chunk ends at
    the mps return, so confirm against the full file.
    """
    if device.type == "cuda":
        # Drop cached blocks so memory_allocated() counts live tensors only.
        torch.cuda.empty_cache()
        return torch.cuda.memory_allocated()
    elif device.type == "mps":
        torch.mps.empty_cache()
        return torch.mps.current_allocated_memory()