quanto
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
14
15import pytest16import torch17from helpers import random_tensor18
19
@pytest.mark.parametrize("input_shape", [[10, 32], [32, 32]])
@pytest.mark.parametrize("output_features", [48, 64])
@pytest.mark.parametrize("dtype", [torch.float16, torch.float32])
def test_dqmm(input_shape, output_features, dtype, device):
    """Check `torch.ops.quanto.dqmm` against an explicit dequantize-then-matmul.

    The op is called with a floating-point input, an int8 `other` matrix and a
    per-output-feature scale. Its result must be exactly equal to
    `aten.mm(input, other * other_scale)`.
    """
    in_features = input_shape[-1]
    inputs = random_tensor(input_shape, dtype=dtype).to(device)
    # `torch.randint` excludes its upper bound: use 128 so that the value 127
    # is also exercised, making the sampled range symmetric ([-127, 127]).
    other = torch.randint(-127, 128, (in_features, output_features), dtype=torch.int8).to(device)
    other_scale = random_tensor((output_features,), dtype=dtype).to(device)
    output = torch.ops.quanto.dqmm(inputs, other, other_scale)
    # Reference: scale the int8 matrix first, then run a plain matmul.
    expected = torch.ops.aten.mm(inputs, other * other_scale)
    assert torch.equal(expected, output)