vision

Форк
0
/
test_ops.py 
1977 строк · 79.7 Кб
1
import math
2
import os
3
from abc import ABC, abstractmethod
4
from functools import lru_cache
5
from itertools import product
6
from typing import Callable, List, Tuple
7

8
import numpy as np
9
import pytest
10
import torch
11
import torch.fx
12
import torch.nn.functional as F
13
import torch.testing._internal.optests as optests
14
from common_utils import assert_equal, cpu_and_cuda, cpu_and_cuda_and_mps, needs_cuda, needs_mps
15
from PIL import Image
16
from torch import nn, Tensor
17
from torch._dynamo.utils import is_compile_supported
18
from torch.autograd import gradcheck
19
from torch.nn.modules.utils import _pair
20
from torchvision import models, ops
21
from torchvision.models.feature_extraction import get_graph_node_names
22

23

24
OPTESTS = [
25
    "test_schema",
26
    "test_autograd_registration",
27
    "test_faketensor",
28
    "test_aot_dispatch_dynamic",
29
]
30

31

32
# Context manager for setting deterministic flag and automatically
33
# resetting it to its original value
34
class DeterministicGuard:
35
    def __init__(self, deterministic, *, warn_only=False):
36
        self.deterministic = deterministic
37
        self.warn_only = warn_only
38

39
    def __enter__(self):
40
        self.deterministic_restore = torch.are_deterministic_algorithms_enabled()
41
        self.warn_only_restore = torch.is_deterministic_algorithms_warn_only_enabled()
42
        torch.use_deterministic_algorithms(self.deterministic, warn_only=self.warn_only)
43

44
    def __exit__(self, exception_type, exception_value, traceback):
45
        torch.use_deterministic_algorithms(self.deterministic_restore, warn_only=self.warn_only_restore)
46

47

48
class RoIOpTesterModuleWrapper(nn.Module):
49
    def __init__(self, obj):
50
        super().__init__()
51
        self.layer = obj
52
        self.n_inputs = 2
53

54
    def forward(self, a, b):
55
        self.layer(a, b)
56

57

58
class MultiScaleRoIAlignModuleWrapper(nn.Module):
59
    def __init__(self, obj):
60
        super().__init__()
61
        self.layer = obj
62
        self.n_inputs = 3
63

64
    def forward(self, a, b, c):
65
        self.layer(a, b, c)
66

67

68
class DeformConvModuleWrapper(nn.Module):
69
    def __init__(self, obj):
70
        super().__init__()
71
        self.layer = obj
72
        self.n_inputs = 3
73

74
    def forward(self, a, b, c):
75
        self.layer(a, b, c)
76

77

78
class StochasticDepthWrapper(nn.Module):
79
    def __init__(self, obj):
80
        super().__init__()
81
        self.layer = obj
82
        self.n_inputs = 1
83

84
    def forward(self, a):
85
        self.layer(a)
86

87

88
class DropBlockWrapper(nn.Module):
89
    def __init__(self, obj):
90
        super().__init__()
91
        self.layer = obj
92
        self.n_inputs = 1
93

94
    def forward(self, a):
95
        self.layer(a)
96

97

98
class PoolWrapper(nn.Module):
99
    def __init__(self, pool: nn.Module):
100
        super().__init__()
101
        self.pool = pool
102

103
    def forward(self, imgs: Tensor, boxes: List[Tensor]) -> Tensor:
104
        return self.pool(imgs, boxes)
105

106

107
class RoIOpTester(ABC):
108
    dtype = torch.float64
109
    mps_dtype = torch.float32
110
    mps_backward_atol = 2e-2
111

112
    @pytest.mark.parametrize("device", cpu_and_cuda_and_mps())
113
    @pytest.mark.parametrize("contiguous", (True, False))
114
    @pytest.mark.parametrize(
115
        "x_dtype",
116
        (
117
            torch.float16,
118
            torch.float32,
119
            torch.float64,
120
        ),
121
        ids=str,
122
    )
123
    def test_forward(self, device, contiguous, x_dtype, rois_dtype=None, deterministic=False, **kwargs):
124
        if device == "mps" and x_dtype is torch.float64:
125
            pytest.skip("MPS does not support float64")
126

127
        rois_dtype = x_dtype if rois_dtype is None else rois_dtype
128

129
        tol = 1e-5
130
        if x_dtype is torch.half:
131
            if device == "mps":
132
                tol = 5e-3
133
            else:
134
                tol = 4e-3
135
        elif x_dtype == torch.bfloat16:
136
            tol = 5e-3
137

138
        pool_size = 5
139
        # n_channels % (pool_size ** 2) == 0 required for PS operations.
140
        n_channels = 2 * (pool_size**2)
141
        x = torch.rand(2, n_channels, 10, 10, dtype=x_dtype, device=device)
142
        if not contiguous:
143
            x = x.permute(0, 1, 3, 2)
144
        rois = torch.tensor(
145
            [[0, 0, 0, 9, 9], [0, 0, 5, 4, 9], [0, 5, 5, 9, 9], [1, 0, 0, 9, 9]],  # format is (xyxy)
146
            dtype=rois_dtype,
147
            device=device,
148
        )
149

150
        pool_h, pool_w = pool_size, pool_size
151
        with DeterministicGuard(deterministic):
152
            y = self.fn(x, rois, pool_h, pool_w, spatial_scale=1, sampling_ratio=-1, **kwargs)
153
        # the following should be true whether we're running an autocast test or not.
154
        assert y.dtype == x.dtype
155
        gt_y = self.expected_fn(
156
            x, rois, pool_h, pool_w, spatial_scale=1, sampling_ratio=-1, device=device, dtype=x_dtype, **kwargs
157
        )
158

159
        torch.testing.assert_close(gt_y.to(y), y, rtol=tol, atol=tol)
160

161
    @pytest.mark.parametrize("device", cpu_and_cuda())
162
    def test_is_leaf_node(self, device):
163
        op_obj = self.make_obj(wrap=True).to(device=device)
164
        graph_node_names = get_graph_node_names(op_obj)
165

166
        assert len(graph_node_names) == 2
167
        assert len(graph_node_names[0]) == len(graph_node_names[1])
168
        assert len(graph_node_names[0]) == 1 + op_obj.n_inputs
169

170
    @pytest.mark.parametrize("device", cpu_and_cuda())
171
    def test_torch_fx_trace(self, device, x_dtype=torch.float, rois_dtype=torch.float):
172
        op_obj = self.make_obj().to(device=device)
173
        graph_module = torch.fx.symbolic_trace(op_obj)
174
        pool_size = 5
175
        n_channels = 2 * (pool_size**2)
176
        x = torch.rand(2, n_channels, 5, 5, dtype=x_dtype, device=device)
177
        rois = torch.tensor(
178
            [[0, 0, 0, 9, 9], [0, 0, 5, 4, 9], [0, 5, 5, 9, 9], [1, 0, 0, 9, 9]],  # format is (xyxy)
179
            dtype=rois_dtype,
180
            device=device,
181
        )
182
        output_gt = op_obj(x, rois)
183
        assert output_gt.dtype == x.dtype
184
        output_fx = graph_module(x, rois)
185
        assert output_fx.dtype == x.dtype
186
        tol = 1e-5
187
        torch.testing.assert_close(output_gt, output_fx, rtol=tol, atol=tol)
188

189
    @pytest.mark.parametrize("seed", range(10))
190
    @pytest.mark.parametrize("device", cpu_and_cuda_and_mps())
191
    @pytest.mark.parametrize("contiguous", (True, False))
192
    def test_backward(self, seed, device, contiguous, deterministic=False):
193
        atol = self.mps_backward_atol if device == "mps" else 1e-05
194
        dtype = self.mps_dtype if device == "mps" else self.dtype
195

196
        torch.random.manual_seed(seed)
197
        pool_size = 2
198
        x = torch.rand(1, 2 * (pool_size**2), 5, 5, dtype=dtype, device=device, requires_grad=True)
199
        if not contiguous:
200
            x = x.permute(0, 1, 3, 2)
201
        rois = torch.tensor(
202
            [[0, 0, 0, 4, 4], [0, 0, 2, 3, 4], [0, 2, 2, 4, 4]], dtype=dtype, device=device  # format is (xyxy)
203
        )
204

205
        def func(z):
206
            return self.fn(z, rois, pool_size, pool_size, spatial_scale=1, sampling_ratio=1)
207

208
        script_func = self.get_script_fn(rois, pool_size)
209

210
        with DeterministicGuard(deterministic):
211
            gradcheck(func, (x,), atol=atol)
212

213
        gradcheck(script_func, (x,), atol=atol)
214

215
    @needs_mps
216
    def test_mps_error_inputs(self):
217
        pool_size = 2
218
        x = torch.rand(1, 2 * (pool_size**2), 5, 5, dtype=torch.float16, device="mps", requires_grad=True)
219
        rois = torch.tensor(
220
            [[0, 0, 0, 4, 4], [0, 0, 2, 3, 4], [0, 2, 2, 4, 4]], dtype=torch.float16, device="mps"  # format is (xyxy)
221
        )
222

223
        def func(z):
224
            return self.fn(z, rois, pool_size, pool_size, spatial_scale=1, sampling_ratio=1)
225

226
        with pytest.raises(
227
            RuntimeError, match="MPS does not support (?:ps_)?roi_(?:align|pool)? backward with float16 inputs."
228
        ):
229
            gradcheck(func, (x,))
230

231
    @needs_cuda
232
    @pytest.mark.parametrize("x_dtype", (torch.float, torch.half))
233
    @pytest.mark.parametrize("rois_dtype", (torch.float, torch.half))
234
    def test_autocast(self, x_dtype, rois_dtype):
235
        with torch.cuda.amp.autocast():
236
            self.test_forward(torch.device("cuda"), contiguous=False, x_dtype=x_dtype, rois_dtype=rois_dtype)
237

238
    def _helper_boxes_shape(self, func):
239
        # test boxes as Tensor[N, 5]
240
        with pytest.raises(AssertionError):
241
            a = torch.linspace(1, 8 * 8, 8 * 8).reshape(1, 1, 8, 8)
242
            boxes = torch.tensor([[0, 0, 3, 3]], dtype=a.dtype)
243
            func(a, boxes, output_size=(2, 2))
244

245
        # test boxes as List[Tensor[N, 4]]
246
        with pytest.raises(AssertionError):
247
            a = torch.linspace(1, 8 * 8, 8 * 8).reshape(1, 1, 8, 8)
248
            boxes = torch.tensor([[0, 0, 3]], dtype=a.dtype)
249
            ops.roi_pool(a, [boxes], output_size=(2, 2))
250

251
    def _helper_jit_boxes_list(self, model):
252
        x = torch.rand(2, 1, 10, 10)
253
        roi = torch.tensor([[0, 0, 0, 9, 9], [0, 0, 5, 4, 9], [0, 5, 5, 9, 9], [1, 0, 0, 9, 9]], dtype=torch.float).t()
254
        rois = [roi, roi]
255
        scriped = torch.jit.script(model)
256
        y = scriped(x, rois)
257
        assert y.shape == (10, 1, 3, 3)
258

259
    @abstractmethod
260
    def fn(*args, **kwargs):
261
        pass
262

263
    @abstractmethod
264
    def make_obj(*args, **kwargs):
265
        pass
266

267
    @abstractmethod
268
    def get_script_fn(*args, **kwargs):
269
        pass
270

271
    @abstractmethod
272
    def expected_fn(*args, **kwargs):
273
        pass
274

275

276
class TestRoiPool(RoIOpTester):
277
    def fn(self, x, rois, pool_h, pool_w, spatial_scale=1, sampling_ratio=-1, **kwargs):
278
        return ops.RoIPool((pool_h, pool_w), spatial_scale)(x, rois)
279

280
    def make_obj(self, pool_h=5, pool_w=5, spatial_scale=1, wrap=False):
281
        obj = ops.RoIPool((pool_h, pool_w), spatial_scale)
282
        return RoIOpTesterModuleWrapper(obj) if wrap else obj
283

284
    def get_script_fn(self, rois, pool_size):
285
        scriped = torch.jit.script(ops.roi_pool)
286
        return lambda x: scriped(x, rois, pool_size)
287

288
    def expected_fn(
289
        self, x, rois, pool_h, pool_w, spatial_scale=1, sampling_ratio=-1, device=None, dtype=torch.float64
290
    ):
291
        if device is None:
292
            device = torch.device("cpu")
293

294
        n_channels = x.size(1)
295
        y = torch.zeros(rois.size(0), n_channels, pool_h, pool_w, dtype=dtype, device=device)
296

297
        def get_slice(k, block):
298
            return slice(int(np.floor(k * block)), int(np.ceil((k + 1) * block)))
299

300
        for roi_idx, roi in enumerate(rois):
301
            batch_idx = int(roi[0])
302
            j_begin, i_begin, j_end, i_end = (int(round(x.item() * spatial_scale)) for x in roi[1:])
303
            roi_x = x[batch_idx, :, i_begin : i_end + 1, j_begin : j_end + 1]
304

305
            roi_h, roi_w = roi_x.shape[-2:]
306
            bin_h = roi_h / pool_h
307
            bin_w = roi_w / pool_w
308

309
            for i in range(0, pool_h):
310
                for j in range(0, pool_w):
311
                    bin_x = roi_x[:, get_slice(i, bin_h), get_slice(j, bin_w)]
312
                    if bin_x.numel() > 0:
313
                        y[roi_idx, :, i, j] = bin_x.reshape(n_channels, -1).max(dim=1)[0]
314
        return y
315

316
    def test_boxes_shape(self):
317
        self._helper_boxes_shape(ops.roi_pool)
318

319
    def test_jit_boxes_list(self):
320
        model = PoolWrapper(ops.RoIPool(output_size=[3, 3], spatial_scale=1.0))
321
        self._helper_jit_boxes_list(model)
322

323

324
class TestPSRoIPool(RoIOpTester):
325
    mps_backward_atol = 5e-2
326

327
    def fn(self, x, rois, pool_h, pool_w, spatial_scale=1, sampling_ratio=-1, **kwargs):
328
        return ops.PSRoIPool((pool_h, pool_w), 1)(x, rois)
329

330
    def make_obj(self, pool_h=5, pool_w=5, spatial_scale=1, wrap=False):
331
        obj = ops.PSRoIPool((pool_h, pool_w), spatial_scale)
332
        return RoIOpTesterModuleWrapper(obj) if wrap else obj
333

334
    def get_script_fn(self, rois, pool_size):
335
        scriped = torch.jit.script(ops.ps_roi_pool)
336
        return lambda x: scriped(x, rois, pool_size)
337

338
    def expected_fn(
339
        self, x, rois, pool_h, pool_w, spatial_scale=1, sampling_ratio=-1, device=None, dtype=torch.float64
340
    ):
341
        if device is None:
342
            device = torch.device("cpu")
343
        n_input_channels = x.size(1)
344
        assert n_input_channels % (pool_h * pool_w) == 0, "input channels must be divisible by ph * pw"
345
        n_output_channels = int(n_input_channels / (pool_h * pool_w))
346
        y = torch.zeros(rois.size(0), n_output_channels, pool_h, pool_w, dtype=dtype, device=device)
347

348
        def get_slice(k, block):
349
            return slice(int(np.floor(k * block)), int(np.ceil((k + 1) * block)))
350

351
        for roi_idx, roi in enumerate(rois):
352
            batch_idx = int(roi[0])
353
            j_begin, i_begin, j_end, i_end = (int(round(x.item() * spatial_scale)) for x in roi[1:])
354
            roi_x = x[batch_idx, :, i_begin : i_end + 1, j_begin : j_end + 1]
355

356
            roi_height = max(i_end - i_begin, 1)
357
            roi_width = max(j_end - j_begin, 1)
358
            bin_h, bin_w = roi_height / float(pool_h), roi_width / float(pool_w)
359

360
            for i in range(0, pool_h):
361
                for j in range(0, pool_w):
362
                    bin_x = roi_x[:, get_slice(i, bin_h), get_slice(j, bin_w)]
363
                    if bin_x.numel() > 0:
364
                        area = bin_x.size(-2) * bin_x.size(-1)
365
                        for c_out in range(0, n_output_channels):
366
                            c_in = c_out * (pool_h * pool_w) + pool_w * i + j
367
                            t = torch.sum(bin_x[c_in, :, :])
368
                            y[roi_idx, c_out, i, j] = t / area
369
        return y
370

371
    def test_boxes_shape(self):
372
        self._helper_boxes_shape(ops.ps_roi_pool)
373

374

375
def bilinear_interpolate(data, y, x, snap_border=False):
376
    height, width = data.shape
377

378
    if snap_border:
379
        if -1 < y <= 0:
380
            y = 0
381
        elif height - 1 <= y < height:
382
            y = height - 1
383

384
        if -1 < x <= 0:
385
            x = 0
386
        elif width - 1 <= x < width:
387
            x = width - 1
388

389
    y_low = int(math.floor(y))
390
    x_low = int(math.floor(x))
391
    y_high = y_low + 1
392
    x_high = x_low + 1
393

394
    wy_h = y - y_low
395
    wx_h = x - x_low
396
    wy_l = 1 - wy_h
397
    wx_l = 1 - wx_h
398

399
    val = 0
400
    for wx, xp in zip((wx_l, wx_h), (x_low, x_high)):
401
        for wy, yp in zip((wy_l, wy_h), (y_low, y_high)):
402
            if 0 <= yp < height and 0 <= xp < width:
403
                val += wx * wy * data[yp, xp]
404
    return val
405

406

407
class TestRoIAlign(RoIOpTester):
408
    mps_backward_atol = 6e-2
409

410
    def fn(self, x, rois, pool_h, pool_w, spatial_scale=1, sampling_ratio=-1, aligned=False, **kwargs):
411
        return ops.RoIAlign(
412
            (pool_h, pool_w), spatial_scale=spatial_scale, sampling_ratio=sampling_ratio, aligned=aligned
413
        )(x, rois)
414

415
    def make_obj(self, pool_h=5, pool_w=5, spatial_scale=1, sampling_ratio=-1, aligned=False, wrap=False):
416
        obj = ops.RoIAlign(
417
            (pool_h, pool_w), spatial_scale=spatial_scale, sampling_ratio=sampling_ratio, aligned=aligned
418
        )
419
        return RoIOpTesterModuleWrapper(obj) if wrap else obj
420

421
    def get_script_fn(self, rois, pool_size):
422
        scriped = torch.jit.script(ops.roi_align)
423
        return lambda x: scriped(x, rois, pool_size)
424

425
    def expected_fn(
426
        self,
427
        in_data,
428
        rois,
429
        pool_h,
430
        pool_w,
431
        spatial_scale=1,
432
        sampling_ratio=-1,
433
        aligned=False,
434
        device=None,
435
        dtype=torch.float64,
436
    ):
437
        if device is None:
438
            device = torch.device("cpu")
439
        n_channels = in_data.size(1)
440
        out_data = torch.zeros(rois.size(0), n_channels, pool_h, pool_w, dtype=dtype, device=device)
441

442
        offset = 0.5 if aligned else 0.0
443

444
        for r, roi in enumerate(rois):
445
            batch_idx = int(roi[0])
446
            j_begin, i_begin, j_end, i_end = (x.item() * spatial_scale - offset for x in roi[1:])
447

448
            roi_h = i_end - i_begin
449
            roi_w = j_end - j_begin
450
            bin_h = roi_h / pool_h
451
            bin_w = roi_w / pool_w
452

453
            for i in range(0, pool_h):
454
                start_h = i_begin + i * bin_h
455
                grid_h = sampling_ratio if sampling_ratio > 0 else int(np.ceil(bin_h))
456
                for j in range(0, pool_w):
457
                    start_w = j_begin + j * bin_w
458
                    grid_w = sampling_ratio if sampling_ratio > 0 else int(np.ceil(bin_w))
459

460
                    for channel in range(0, n_channels):
461
                        val = 0
462
                        for iy in range(0, grid_h):
463
                            y = start_h + (iy + 0.5) * bin_h / grid_h
464
                            for ix in range(0, grid_w):
465
                                x = start_w + (ix + 0.5) * bin_w / grid_w
466
                                val += bilinear_interpolate(in_data[batch_idx, channel, :, :], y, x, snap_border=True)
467
                        val /= grid_h * grid_w
468

469
                        out_data[r, channel, i, j] = val
470
        return out_data
471

472
    def test_boxes_shape(self):
473
        self._helper_boxes_shape(ops.roi_align)
474

475
    @pytest.mark.parametrize("aligned", (True, False))
476
    @pytest.mark.parametrize("device", cpu_and_cuda_and_mps())
477
    @pytest.mark.parametrize("x_dtype", (torch.float16, torch.float32, torch.float64))  # , ids=str)
478
    @pytest.mark.parametrize("contiguous", (True, False))
479
    @pytest.mark.parametrize("deterministic", (True, False))
480
    @pytest.mark.opcheck_only_one()
481
    def test_forward(self, device, contiguous, deterministic, aligned, x_dtype, rois_dtype=None):
482
        if deterministic and device == "cpu":
483
            pytest.skip("cpu is always deterministic, don't retest")
484
        super().test_forward(
485
            device=device,
486
            contiguous=contiguous,
487
            deterministic=deterministic,
488
            x_dtype=x_dtype,
489
            rois_dtype=rois_dtype,
490
            aligned=aligned,
491
        )
492

493
    @needs_cuda
494
    @pytest.mark.parametrize("aligned", (True, False))
495
    @pytest.mark.parametrize("deterministic", (True, False))
496
    @pytest.mark.parametrize("x_dtype", (torch.float, torch.half))
497
    @pytest.mark.parametrize("rois_dtype", (torch.float, torch.half))
498
    @pytest.mark.opcheck_only_one()
499
    def test_autocast(self, aligned, deterministic, x_dtype, rois_dtype):
500
        with torch.cuda.amp.autocast():
501
            self.test_forward(
502
                torch.device("cuda"),
503
                contiguous=False,
504
                deterministic=deterministic,
505
                aligned=aligned,
506
                x_dtype=x_dtype,
507
                rois_dtype=rois_dtype,
508
            )
509

510
    @pytest.mark.skip(reason="1/5000 flaky failure")
511
    @pytest.mark.parametrize("aligned", (True, False))
512
    @pytest.mark.parametrize("deterministic", (True, False))
513
    @pytest.mark.parametrize("x_dtype", (torch.float, torch.bfloat16))
514
    @pytest.mark.parametrize("rois_dtype", (torch.float, torch.bfloat16))
515
    def test_autocast_cpu(self, aligned, deterministic, x_dtype, rois_dtype):
516
        with torch.cpu.amp.autocast():
517
            self.test_forward(
518
                torch.device("cpu"),
519
                contiguous=False,
520
                deterministic=deterministic,
521
                aligned=aligned,
522
                x_dtype=x_dtype,
523
                rois_dtype=rois_dtype,
524
            )
525

526
    @pytest.mark.parametrize("seed", range(10))
527
    @pytest.mark.parametrize("device", cpu_and_cuda_and_mps())
528
    @pytest.mark.parametrize("contiguous", (True, False))
529
    @pytest.mark.parametrize("deterministic", (True, False))
530
    @pytest.mark.opcheck_only_one()
531
    def test_backward(self, seed, device, contiguous, deterministic):
532
        if deterministic and device == "cpu":
533
            pytest.skip("cpu is always deterministic, don't retest")
534
        if deterministic and device == "mps":
535
            pytest.skip("no deterministic implementation for mps")
536
        if deterministic and not is_compile_supported(device):
537
            pytest.skip("deterministic implementation only if torch.compile supported")
538
        super().test_backward(seed, device, contiguous, deterministic)
539

540
    def _make_rois(self, img_size, num_imgs, dtype, num_rois=1000):
541
        rois = torch.randint(0, img_size // 2, size=(num_rois, 5)).to(dtype)
542
        rois[:, 0] = torch.randint(0, num_imgs, size=(num_rois,))  # set batch index
543
        rois[:, 3:] += rois[:, 1:3]  # make sure boxes aren't degenerate
544
        return rois
545

546
    @pytest.mark.parametrize("aligned", (True, False))
547
    @pytest.mark.parametrize("scale, zero_point", ((1, 0), (2, 10), (0.1, 50)))
548
    @pytest.mark.parametrize("qdtype", (torch.qint8, torch.quint8, torch.qint32))
549
    @pytest.mark.opcheck_only_one()
550
    def test_qroialign(self, aligned, scale, zero_point, qdtype):
551
        """Make sure quantized version of RoIAlign is close to float version"""
552
        pool_size = 5
553
        img_size = 10
554
        n_channels = 2
555
        num_imgs = 1
556
        dtype = torch.float
557

558
        x = torch.randint(50, 100, size=(num_imgs, n_channels, img_size, img_size)).to(dtype)
559
        qx = torch.quantize_per_tensor(x, scale=scale, zero_point=zero_point, dtype=qdtype)
560

561
        rois = self._make_rois(img_size, num_imgs, dtype)
562
        qrois = torch.quantize_per_tensor(rois, scale=scale, zero_point=zero_point, dtype=qdtype)
563

564
        x, rois = qx.dequantize(), qrois.dequantize()  # we want to pass the same inputs
565

566
        y = ops.roi_align(
567
            x,
568
            rois,
569
            output_size=pool_size,
570
            spatial_scale=1,
571
            sampling_ratio=-1,
572
            aligned=aligned,
573
        )
574
        qy = ops.roi_align(
575
            qx,
576
            qrois,
577
            output_size=pool_size,
578
            spatial_scale=1,
579
            sampling_ratio=-1,
580
            aligned=aligned,
581
        )
582

583
        # The output qy is itself a quantized tensor and there might have been a loss of info when it was
584
        # quantized. For a fair comparison we need to quantize y as well
585
        quantized_float_y = torch.quantize_per_tensor(y, scale=scale, zero_point=zero_point, dtype=qdtype)
586

587
        try:
588
            # Ideally, we would assert this, which passes with (scale, zero) == (1, 0)
589
            assert (qy == quantized_float_y).all()
590
        except AssertionError:
591
            # But because the computation aren't exactly the same between the 2 RoIAlign procedures, some
592
            # rounding error may lead to a difference of 2 in the output.
593
            # For example with (scale, zero) = (2, 10), 45.00000... will be quantized to 44
594
            # but 45.00000001 will be rounded to 46. We make sure below that:
595
            # - such discrepancies between qy and quantized_float_y are very rare (less then 5%)
596
            # - any difference between qy and quantized_float_y is == scale
597
            diff_idx = torch.where(qy != quantized_float_y)
598
            num_diff = diff_idx[0].numel()
599
            assert num_diff / qy.numel() < 0.05
600

601
            abs_diff = torch.abs(qy[diff_idx].dequantize() - quantized_float_y[diff_idx].dequantize())
602
            t_scale = torch.full_like(abs_diff, fill_value=scale)
603
            torch.testing.assert_close(abs_diff, t_scale, rtol=1e-5, atol=1e-5)
604

605
    def test_qroi_align_multiple_images(self):
606
        dtype = torch.float
607
        x = torch.randint(50, 100, size=(2, 3, 10, 10)).to(dtype)
608
        qx = torch.quantize_per_tensor(x, scale=1, zero_point=0, dtype=torch.qint8)
609
        rois = self._make_rois(img_size=10, num_imgs=2, dtype=dtype, num_rois=10)
610
        qrois = torch.quantize_per_tensor(rois, scale=1, zero_point=0, dtype=torch.qint8)
611
        with pytest.raises(RuntimeError, match="Only one image per batch is allowed"):
612
            ops.roi_align(qx, qrois, output_size=5)
613

614
    def test_jit_boxes_list(self):
615
        model = PoolWrapper(ops.RoIAlign(output_size=[3, 3], spatial_scale=1.0, sampling_ratio=-1))
616
        self._helper_jit_boxes_list(model)
617

618

619
class TestPSRoIAlign(RoIOpTester):
620
    mps_backward_atol = 5e-2
621

622
    def fn(self, x, rois, pool_h, pool_w, spatial_scale=1, sampling_ratio=-1, **kwargs):
623
        return ops.PSRoIAlign((pool_h, pool_w), spatial_scale=spatial_scale, sampling_ratio=sampling_ratio)(x, rois)
624

625
    def make_obj(self, pool_h=5, pool_w=5, spatial_scale=1, sampling_ratio=-1, wrap=False):
626
        obj = ops.PSRoIAlign((pool_h, pool_w), spatial_scale=spatial_scale, sampling_ratio=sampling_ratio)
627
        return RoIOpTesterModuleWrapper(obj) if wrap else obj
628

629
    def get_script_fn(self, rois, pool_size):
630
        scriped = torch.jit.script(ops.ps_roi_align)
631
        return lambda x: scriped(x, rois, pool_size)
632

633
    def expected_fn(
634
        self, in_data, rois, pool_h, pool_w, device, spatial_scale=1, sampling_ratio=-1, dtype=torch.float64
635
    ):
636
        if device is None:
637
            device = torch.device("cpu")
638
        n_input_channels = in_data.size(1)
639
        assert n_input_channels % (pool_h * pool_w) == 0, "input channels must be divisible by ph * pw"
640
        n_output_channels = int(n_input_channels / (pool_h * pool_w))
641
        out_data = torch.zeros(rois.size(0), n_output_channels, pool_h, pool_w, dtype=dtype, device=device)
642

643
        for r, roi in enumerate(rois):
644
            batch_idx = int(roi[0])
645
            j_begin, i_begin, j_end, i_end = (x.item() * spatial_scale - 0.5 for x in roi[1:])
646

647
            roi_h = i_end - i_begin
648
            roi_w = j_end - j_begin
649
            bin_h = roi_h / pool_h
650
            bin_w = roi_w / pool_w
651

652
            for i in range(0, pool_h):
653
                start_h = i_begin + i * bin_h
654
                grid_h = sampling_ratio if sampling_ratio > 0 else int(np.ceil(bin_h))
655
                for j in range(0, pool_w):
656
                    start_w = j_begin + j * bin_w
657
                    grid_w = sampling_ratio if sampling_ratio > 0 else int(np.ceil(bin_w))
658
                    for c_out in range(0, n_output_channels):
659
                        c_in = c_out * (pool_h * pool_w) + pool_w * i + j
660

661
                        val = 0
662
                        for iy in range(0, grid_h):
663
                            y = start_h + (iy + 0.5) * bin_h / grid_h
664
                            for ix in range(0, grid_w):
665
                                x = start_w + (ix + 0.5) * bin_w / grid_w
666
                                val += bilinear_interpolate(in_data[batch_idx, c_in, :, :], y, x, snap_border=True)
667
                        val /= grid_h * grid_w
668

669
                        out_data[r, c_out, i, j] = val
670
        return out_data
671

672
    def test_boxes_shape(self):
673
        self._helper_boxes_shape(ops.ps_roi_align)
674

675

676
@pytest.mark.parametrize(
677
    "op",
678
    (
679
        torch.ops.torchvision.roi_pool,
680
        torch.ops.torchvision.ps_roi_pool,
681
        torch.ops.torchvision.roi_align,
682
        torch.ops.torchvision.ps_roi_align,
683
    ),
684
)
685
@pytest.mark.parametrize("dtype", (torch.float16, torch.float32, torch.float64))
686
@pytest.mark.parametrize("device", cpu_and_cuda())
687
@pytest.mark.parametrize("requires_grad", (True, False))
688
def test_roi_opcheck(op, dtype, device, requires_grad):
689
    # This manually calls opcheck() on the roi ops. We do that instead of
690
    # relying on opcheck.generate_opcheck_tests() as e.g. done for nms, because
691
    # pytest and generate_opcheck_tests() don't interact very well when it comes
692
    # to skipping tests - and these ops need to skip the MPS tests since MPS we
693
    # don't support dynamic shapes yet for MPS.
694
    rois = torch.tensor(
695
        [[0, 0, 0, 9, 9], [0, 0, 5, 4, 9], [0, 5, 5, 9, 9], [1, 0, 0, 9, 9]],
696
        dtype=dtype,
697
        device=device,
698
        requires_grad=requires_grad,
699
    )
700
    pool_size = 5
701
    num_channels = 2 * (pool_size**2)
702
    x = torch.rand(2, num_channels, 10, 10, dtype=dtype, device=device)
703

704
    kwargs = dict(rois=rois, spatial_scale=1, pooled_height=pool_size, pooled_width=pool_size)
705
    if op in (torch.ops.torchvision.roi_align, torch.ops.torchvision.ps_roi_align):
706
        kwargs["sampling_ratio"] = -1
707
    if op is torch.ops.torchvision.roi_align:
708
        kwargs["aligned"] = True
709

710
    optests.opcheck(op, args=(x,), kwargs=kwargs)
711

712

713
class TestMultiScaleRoIAlign:
714
    def make_obj(self, fmap_names=None, output_size=(7, 7), sampling_ratio=2, wrap=False):
715
        if fmap_names is None:
716
            fmap_names = ["0"]
717
        obj = ops.poolers.MultiScaleRoIAlign(fmap_names, output_size, sampling_ratio)
718
        return MultiScaleRoIAlignModuleWrapper(obj) if wrap else obj
719

720
    def test_msroialign_repr(self):
721
        fmap_names = ["0"]
722
        output_size = (7, 7)
723
        sampling_ratio = 2
724
        # Pass mock feature map names
725
        t = self.make_obj(fmap_names, output_size, sampling_ratio, wrap=False)
726

727
        # Check integrity of object __repr__ attribute
728
        expected_string = (
729
            f"MultiScaleRoIAlign(featmap_names={fmap_names}, output_size={output_size}, "
730
            f"sampling_ratio={sampling_ratio})"
731
        )
732
        assert repr(t) == expected_string
733

734
    @pytest.mark.parametrize("device", cpu_and_cuda())
735
    def test_is_leaf_node(self, device):
736
        op_obj = self.make_obj(wrap=True).to(device=device)
737
        graph_node_names = get_graph_node_names(op_obj)
738

739
        assert len(graph_node_names) == 2
740
        assert len(graph_node_names[0]) == len(graph_node_names[1])
741
        assert len(graph_node_names[0]) == 1 + op_obj.n_inputs
742

743

744
class TestNMS:
745
    def _reference_nms(self, boxes, scores, iou_threshold):
746
        """
747
        Args:
748
            boxes: boxes in corner-form
749
            scores: probabilities
750
            iou_threshold: intersection over union threshold
751
        Returns:
752
             picked: a list of indexes of the kept boxes
753
        """
754
        picked = []
755
        _, indexes = scores.sort(descending=True)
756
        while len(indexes) > 0:
757
            current = indexes[0]
758
            picked.append(current.item())
759
            if len(indexes) == 1:
760
                break
761
            current_box = boxes[current, :]
762
            indexes = indexes[1:]
763
            rest_boxes = boxes[indexes, :]
764
            iou = ops.box_iou(rest_boxes, current_box.unsqueeze(0)).squeeze(1)
765
            indexes = indexes[iou <= iou_threshold]
766

767
        return torch.as_tensor(picked)
768

769
    def _create_tensors_with_iou(self, N, iou_thresh):
770
        # force last box to have a pre-defined iou with the first box
771
        # let b0 be [x0, y0, x1, y1], and b1 be [x0, y0, x1 + d, y1],
772
        # then, in order to satisfy ops.iou(b0, b1) == iou_thresh,
773
        # we need to have d = (x1 - x0) * (1 - iou_thresh) / iou_thresh
774
        # Adjust the threshold upward a bit with the intent of creating
775
        # at least one box that exceeds (barely) the threshold and so
776
        # should be suppressed.
777
        boxes = torch.rand(N, 4) * 100
778
        boxes[:, 2:] += boxes[:, :2]
779
        boxes[-1, :] = boxes[0, :]
780
        x0, y0, x1, y1 = boxes[-1].tolist()
781
        iou_thresh += 1e-5
782
        boxes[-1, 2] += (x1 - x0) * (1 - iou_thresh) / iou_thresh
783
        scores = torch.rand(N)
784
        return boxes, scores
785

786
    @pytest.mark.parametrize("iou", (0.2, 0.5, 0.8))
787
    @pytest.mark.parametrize("seed", range(10))
788
    @pytest.mark.opcheck_only_one()
789
    def test_nms_ref(self, iou, seed):
790
        torch.random.manual_seed(seed)
791
        err_msg = "NMS incompatible between CPU and reference implementation for IoU={}"
792
        boxes, scores = self._create_tensors_with_iou(1000, iou)
793
        keep_ref = self._reference_nms(boxes, scores, iou)
794
        keep = ops.nms(boxes, scores, iou)
795
        torch.testing.assert_close(keep, keep_ref, msg=err_msg.format(iou))
796

797
    def test_nms_input_errors(self):
798
        with pytest.raises(RuntimeError):
799
            ops.nms(torch.rand(4), torch.rand(3), 0.5)
800
        with pytest.raises(RuntimeError):
801
            ops.nms(torch.rand(3, 5), torch.rand(3), 0.5)
802
        with pytest.raises(RuntimeError):
803
            ops.nms(torch.rand(3, 4), torch.rand(3, 2), 0.5)
804
        with pytest.raises(RuntimeError):
805
            ops.nms(torch.rand(3, 4), torch.rand(4), 0.5)
806

807
    @pytest.mark.parametrize("iou", (0.2, 0.5, 0.8))
808
    @pytest.mark.parametrize("scale, zero_point", ((1, 0), (2, 50), (3, 10)))
809
    @pytest.mark.opcheck_only_one()
810
    def test_qnms(self, iou, scale, zero_point):
811
        # Note: we compare qnms vs nms instead of qnms vs reference implementation.
812
        # This is because with the int conversion, the trick used in _create_tensors_with_iou
813
        # doesn't really work (in fact, nms vs reference implem will also fail with ints)
814
        err_msg = "NMS and QNMS give different results for IoU={}"
815
        boxes, scores = self._create_tensors_with_iou(1000, iou)
816
        scores *= 100  # otherwise most scores would be 0 or 1 after int conversion
817

818
        qboxes = torch.quantize_per_tensor(boxes, scale=scale, zero_point=zero_point, dtype=torch.quint8)
819
        qscores = torch.quantize_per_tensor(scores, scale=scale, zero_point=zero_point, dtype=torch.quint8)
820

821
        boxes = qboxes.dequantize()
822
        scores = qscores.dequantize()
823

824
        keep = ops.nms(boxes, scores, iou)
825
        qkeep = ops.nms(qboxes, qscores, iou)
826

827
        torch.testing.assert_close(qkeep, keep, msg=err_msg.format(iou))
828

829
    @pytest.mark.parametrize(
830
        "device",
831
        (
832
            pytest.param("cuda", marks=pytest.mark.needs_cuda),
833
            pytest.param("mps", marks=pytest.mark.needs_mps),
834
        ),
835
    )
836
    @pytest.mark.parametrize("iou", (0.2, 0.5, 0.8))
837
    @pytest.mark.opcheck_only_one()
838
    def test_nms_gpu(self, iou, device, dtype=torch.float64):
839
        dtype = torch.float32 if device == "mps" else dtype
840
        tol = 1e-3 if dtype is torch.half else 1e-5
841
        err_msg = "NMS incompatible between CPU and CUDA for IoU={}"
842

843
        boxes, scores = self._create_tensors_with_iou(1000, iou)
844
        r_cpu = ops.nms(boxes, scores, iou)
845
        r_gpu = ops.nms(boxes.to(device), scores.to(device), iou)
846

847
        is_eq = torch.allclose(r_cpu, r_gpu.cpu())
848
        if not is_eq:
849
            # if the indices are not the same, ensure that it's because the scores
850
            # are duplicate
851
            is_eq = torch.allclose(scores[r_cpu], scores[r_gpu.cpu()], rtol=tol, atol=tol)
852
        assert is_eq, err_msg.format(iou)
853

854
    @needs_cuda
855
    @pytest.mark.parametrize("iou", (0.2, 0.5, 0.8))
856
    @pytest.mark.parametrize("dtype", (torch.float, torch.half))
857
    @pytest.mark.opcheck_only_one()
858
    def test_autocast(self, iou, dtype):
859
        with torch.cuda.amp.autocast():
860
            self.test_nms_gpu(iou=iou, dtype=dtype, device="cuda")
861

862
    @pytest.mark.parametrize("iou", (0.2, 0.5, 0.8))
863
    @pytest.mark.parametrize("dtype", (torch.float, torch.bfloat16))
864
    def test_autocast_cpu(self, iou, dtype):
865
        boxes, scores = self._create_tensors_with_iou(1000, iou)
866
        with torch.cpu.amp.autocast():
867
            keep_ref_float = ops.nms(boxes.to(dtype).float(), scores.to(dtype).float(), iou)
868
            keep_dtype = ops.nms(boxes.to(dtype), scores.to(dtype), iou)
869
        torch.testing.assert_close(keep_ref_float, keep_dtype)
870

871
    @pytest.mark.parametrize(
872
        "device",
873
        (
874
            pytest.param("cuda", marks=pytest.mark.needs_cuda),
875
            pytest.param("mps", marks=pytest.mark.needs_mps),
876
        ),
877
    )
878
    @pytest.mark.opcheck_only_one()
879
    def test_nms_float16(self, device):
880
        boxes = torch.tensor(
881
            [
882
                [285.3538, 185.5758, 1193.5110, 851.4551],
883
                [285.1472, 188.7374, 1192.4984, 851.0669],
884
                [279.2440, 197.9812, 1189.4746, 849.2019],
885
            ]
886
        ).to(device)
887
        scores = torch.tensor([0.6370, 0.7569, 0.3966]).to(device)
888

889
        iou_thres = 0.2
890
        keep32 = ops.nms(boxes, scores, iou_thres)
891
        keep16 = ops.nms(boxes.to(torch.float16), scores.to(torch.float16), iou_thres)
892
        assert_equal(keep32, keep16)
893

894
    @pytest.mark.parametrize("seed", range(10))
895
    @pytest.mark.opcheck_only_one()
896
    def test_batched_nms_implementations(self, seed):
897
        """Make sure that both implementations of batched_nms yield identical results"""
898
        torch.random.manual_seed(seed)
899

900
        num_boxes = 1000
901
        iou_threshold = 0.9
902

903
        boxes = torch.cat((torch.rand(num_boxes, 2), torch.rand(num_boxes, 2) + 10), dim=1)
904
        assert max(boxes[:, 0]) < min(boxes[:, 2])  # x1 < x2
905
        assert max(boxes[:, 1]) < min(boxes[:, 3])  # y1 < y2
906

907
        scores = torch.rand(num_boxes)
908
        idxs = torch.randint(0, 4, size=(num_boxes,))
909
        keep_vanilla = ops.boxes._batched_nms_vanilla(boxes, scores, idxs, iou_threshold)
910
        keep_trick = ops.boxes._batched_nms_coordinate_trick(boxes, scores, idxs, iou_threshold)
911

912
        torch.testing.assert_close(
913
            keep_vanilla, keep_trick, msg="The vanilla and the trick implementation yield different nms outputs."
914
        )
915

916
        # Also make sure an empty tensor is returned if boxes is empty
917
        empty = torch.empty((0,), dtype=torch.int64)
918
        torch.testing.assert_close(empty, ops.batched_nms(empty, None, None, None))
919

920

921
optests.generate_opcheck_tests(
922
    testcase=TestNMS,
923
    namespaces=["torchvision"],
924
    failures_dict_path=os.path.join(os.path.dirname(__file__), "optests_failures_dict.json"),
925
    additional_decorators=[],
926
    test_utils=OPTESTS,
927
)
928

929

930
class TestDeformConv:
931
    dtype = torch.float64
932

933
    def expected_fn(self, x, weight, offset, mask, bias, stride=1, padding=0, dilation=1):
934
        stride_h, stride_w = _pair(stride)
935
        pad_h, pad_w = _pair(padding)
936
        dil_h, dil_w = _pair(dilation)
937
        weight_h, weight_w = weight.shape[-2:]
938

939
        n_batches, n_in_channels, in_h, in_w = x.shape
940
        n_out_channels = weight.shape[0]
941

942
        out_h = (in_h + 2 * pad_h - (dil_h * (weight_h - 1) + 1)) // stride_h + 1
943
        out_w = (in_w + 2 * pad_w - (dil_w * (weight_w - 1) + 1)) // stride_w + 1
944

945
        n_offset_grps = offset.shape[1] // (2 * weight_h * weight_w)
946
        in_c_per_offset_grp = n_in_channels // n_offset_grps
947

948
        n_weight_grps = n_in_channels // weight.shape[1]
949
        in_c_per_weight_grp = weight.shape[1]
950
        out_c_per_weight_grp = n_out_channels // n_weight_grps
951

952
        out = torch.zeros(n_batches, n_out_channels, out_h, out_w, device=x.device, dtype=x.dtype)
953
        for b in range(n_batches):
954
            for c_out in range(n_out_channels):
955
                for i in range(out_h):
956
                    for j in range(out_w):
957
                        for di in range(weight_h):
958
                            for dj in range(weight_w):
959
                                for c in range(in_c_per_weight_grp):
960
                                    weight_grp = c_out // out_c_per_weight_grp
961
                                    c_in = weight_grp * in_c_per_weight_grp + c
962

963
                                    offset_grp = c_in // in_c_per_offset_grp
964
                                    mask_idx = offset_grp * (weight_h * weight_w) + di * weight_w + dj
965
                                    offset_idx = 2 * mask_idx
966

967
                                    pi = stride_h * i - pad_h + dil_h * di + offset[b, offset_idx, i, j]
968
                                    pj = stride_w * j - pad_w + dil_w * dj + offset[b, offset_idx + 1, i, j]
969

970
                                    mask_value = 1.0
971
                                    if mask is not None:
972
                                        mask_value = mask[b, mask_idx, i, j]
973

974
                                    out[b, c_out, i, j] += (
975
                                        mask_value
976
                                        * weight[c_out, c, di, dj]
977
                                        * bilinear_interpolate(x[b, c_in, :, :], pi, pj)
978
                                    )
979
        out += bias.view(1, n_out_channels, 1, 1)
980
        return out
981

982
    @lru_cache(maxsize=None)
983
    def get_fn_args(self, device, contiguous, batch_sz, dtype):
984
        n_in_channels = 6
985
        n_out_channels = 2
986
        n_weight_grps = 2
987
        n_offset_grps = 3
988

989
        stride = (2, 1)
990
        pad = (1, 0)
991
        dilation = (2, 1)
992

993
        stride_h, stride_w = stride
994
        pad_h, pad_w = pad
995
        dil_h, dil_w = dilation
996
        weight_h, weight_w = (3, 2)
997
        in_h, in_w = (5, 4)
998

999
        out_h = (in_h + 2 * pad_h - (dil_h * (weight_h - 1) + 1)) // stride_h + 1
1000
        out_w = (in_w + 2 * pad_w - (dil_w * (weight_w - 1) + 1)) // stride_w + 1
1001

1002
        x = torch.rand(batch_sz, n_in_channels, in_h, in_w, device=device, dtype=dtype, requires_grad=True)
1003

1004
        offset = torch.randn(
1005
            batch_sz,
1006
            n_offset_grps * 2 * weight_h * weight_w,
1007
            out_h,
1008
            out_w,
1009
            device=device,
1010
            dtype=dtype,
1011
            requires_grad=True,
1012
        )
1013

1014
        mask = torch.randn(
1015
            batch_sz, n_offset_grps * weight_h * weight_w, out_h, out_w, device=device, dtype=dtype, requires_grad=True
1016
        )
1017

1018
        weight = torch.randn(
1019
            n_out_channels,
1020
            n_in_channels // n_weight_grps,
1021
            weight_h,
1022
            weight_w,
1023
            device=device,
1024
            dtype=dtype,
1025
            requires_grad=True,
1026
        )
1027

1028
        bias = torch.randn(n_out_channels, device=device, dtype=dtype, requires_grad=True)
1029

1030
        if not contiguous:
1031
            x = x.permute(0, 1, 3, 2).contiguous().permute(0, 1, 3, 2)
1032
            offset = offset.permute(1, 3, 0, 2).contiguous().permute(2, 0, 3, 1)
1033
            mask = mask.permute(1, 3, 0, 2).contiguous().permute(2, 0, 3, 1)
1034
            weight = weight.permute(3, 2, 0, 1).contiguous().permute(2, 3, 1, 0)
1035

1036
        return x, weight, offset, mask, bias, stride, pad, dilation
1037

1038
    def make_obj(self, in_channels=6, out_channels=2, kernel_size=(3, 2), groups=2, wrap=False):
1039
        obj = ops.DeformConv2d(
1040
            in_channels, out_channels, kernel_size, stride=(2, 1), padding=(1, 0), dilation=(2, 1), groups=groups
1041
        )
1042
        return DeformConvModuleWrapper(obj) if wrap else obj
1043

1044
    @pytest.mark.parametrize("device", cpu_and_cuda())
1045
    def test_is_leaf_node(self, device):
1046
        op_obj = self.make_obj(wrap=True).to(device=device)
1047
        graph_node_names = get_graph_node_names(op_obj)
1048

1049
        assert len(graph_node_names) == 2
1050
        assert len(graph_node_names[0]) == len(graph_node_names[1])
1051
        assert len(graph_node_names[0]) == 1 + op_obj.n_inputs
1052

1053
    @pytest.mark.parametrize("device", cpu_and_cuda())
1054
    @pytest.mark.parametrize("contiguous", (True, False))
1055
    @pytest.mark.parametrize("batch_sz", (0, 33))
1056
    @pytest.mark.opcheck_only_one()
1057
    def test_forward(self, device, contiguous, batch_sz, dtype=None):
1058
        dtype = dtype or self.dtype
1059
        x, _, offset, mask, _, stride, padding, dilation = self.get_fn_args(device, contiguous, batch_sz, dtype)
1060
        in_channels = 6
1061
        out_channels = 2
1062
        kernel_size = (3, 2)
1063
        groups = 2
1064
        tol = 2e-3 if dtype is torch.half else 1e-5
1065

1066
        layer = self.make_obj(in_channels, out_channels, kernel_size, groups, wrap=False).to(
1067
            device=x.device, dtype=dtype
1068
        )
1069
        res = layer(x, offset, mask)
1070

1071
        weight = layer.weight.data
1072
        bias = layer.bias.data
1073
        expected = self.expected_fn(x, weight, offset, mask, bias, stride=stride, padding=padding, dilation=dilation)
1074

1075
        torch.testing.assert_close(
1076
            res.to(expected), expected, rtol=tol, atol=tol, msg=f"\nres:\n{res}\nexpected:\n{expected}"
1077
        )
1078

1079
        # no modulation test
1080
        res = layer(x, offset)
1081
        expected = self.expected_fn(x, weight, offset, None, bias, stride=stride, padding=padding, dilation=dilation)
1082

1083
        torch.testing.assert_close(
1084
            res.to(expected), expected, rtol=tol, atol=tol, msg=f"\nres:\n{res}\nexpected:\n{expected}"
1085
        )
1086

1087
    def test_wrong_sizes(self):
1088
        in_channels = 6
1089
        out_channels = 2
1090
        kernel_size = (3, 2)
1091
        groups = 2
1092
        x, _, offset, mask, _, stride, padding, dilation = self.get_fn_args(
1093
            "cpu", contiguous=True, batch_sz=10, dtype=self.dtype
1094
        )
1095
        layer = ops.DeformConv2d(
1096
            in_channels, out_channels, kernel_size, stride=stride, padding=padding, dilation=dilation, groups=groups
1097
        )
1098
        with pytest.raises(RuntimeError, match="the shape of the offset"):
1099
            wrong_offset = torch.rand_like(offset[:, :2])
1100
            layer(x, wrong_offset)
1101

1102
        with pytest.raises(RuntimeError, match=r"mask.shape\[1\] is not valid"):
1103
            wrong_mask = torch.rand_like(mask[:, :2])
1104
            layer(x, offset, wrong_mask)
1105

1106
    @pytest.mark.parametrize("device", cpu_and_cuda())
1107
    @pytest.mark.parametrize("contiguous", (True, False))
1108
    @pytest.mark.parametrize("batch_sz", (0, 33))
1109
    @pytest.mark.opcheck_only_one()
1110
    def test_backward(self, device, contiguous, batch_sz):
1111
        x, weight, offset, mask, bias, stride, padding, dilation = self.get_fn_args(
1112
            device, contiguous, batch_sz, self.dtype
1113
        )
1114

1115
        def func(x_, offset_, mask_, weight_, bias_):
1116
            return ops.deform_conv2d(
1117
                x_, offset_, weight_, bias_, stride=stride, padding=padding, dilation=dilation, mask=mask_
1118
            )
1119

1120
        gradcheck(func, (x, offset, mask, weight, bias), nondet_tol=1e-5, fast_mode=True)
1121

1122
        def func_no_mask(x_, offset_, weight_, bias_):
1123
            return ops.deform_conv2d(
1124
                x_, offset_, weight_, bias_, stride=stride, padding=padding, dilation=dilation, mask=None
1125
            )
1126

1127
        gradcheck(func_no_mask, (x, offset, weight, bias), nondet_tol=1e-5, fast_mode=True)
1128

1129
        @torch.jit.script
1130
        def script_func(x_, offset_, mask_, weight_, bias_, stride_, pad_, dilation_):
1131
            # type:(Tensor, Tensor, Tensor, Tensor, Tensor, Tuple[int, int], Tuple[int, int], Tuple[int, int])->Tensor
1132
            return ops.deform_conv2d(
1133
                x_, offset_, weight_, bias_, stride=stride_, padding=pad_, dilation=dilation_, mask=mask_
1134
            )
1135

1136
        gradcheck(
1137
            lambda z, off, msk, wei, bi: script_func(z, off, msk, wei, bi, stride, padding, dilation),
1138
            (x, offset, mask, weight, bias),
1139
            nondet_tol=1e-5,
1140
            fast_mode=True,
1141
        )
1142

1143
        @torch.jit.script
1144
        def script_func_no_mask(x_, offset_, weight_, bias_, stride_, pad_, dilation_):
1145
            # type:(Tensor, Tensor, Tensor, Tensor, Tuple[int, int], Tuple[int, int], Tuple[int, int])->Tensor
1146
            return ops.deform_conv2d(
1147
                x_, offset_, weight_, bias_, stride=stride_, padding=pad_, dilation=dilation_, mask=None
1148
            )
1149

1150
        gradcheck(
1151
            lambda z, off, wei, bi: script_func_no_mask(z, off, wei, bi, stride, padding, dilation),
1152
            (x, offset, weight, bias),
1153
            nondet_tol=1e-5,
1154
            fast_mode=True,
1155
        )
1156

1157
    @needs_cuda
1158
    @pytest.mark.parametrize("contiguous", (True, False))
1159
    @pytest.mark.opcheck_only_one()
1160
    def test_compare_cpu_cuda_grads(self, contiguous):
1161
        # Test from https://github.com/pytorch/vision/issues/2598
1162
        # Run on CUDA only
1163

1164
        # compare grads computed on CUDA with grads computed on CPU
1165
        true_cpu_grads = None
1166

1167
        init_weight = torch.randn(9, 9, 3, 3, requires_grad=True)
1168
        img = torch.randn(8, 9, 1000, 110)
1169
        offset = torch.rand(8, 2 * 3 * 3, 1000, 110)
1170
        mask = torch.rand(8, 3 * 3, 1000, 110)
1171

1172
        if not contiguous:
1173
            img = img.permute(0, 1, 3, 2).contiguous().permute(0, 1, 3, 2)
1174
            offset = offset.permute(1, 3, 0, 2).contiguous().permute(2, 0, 3, 1)
1175
            mask = mask.permute(1, 3, 0, 2).contiguous().permute(2, 0, 3, 1)
1176
            weight = init_weight.permute(3, 2, 0, 1).contiguous().permute(2, 3, 1, 0)
1177
        else:
1178
            weight = init_weight
1179

1180
        for d in ["cpu", "cuda"]:
1181
            out = ops.deform_conv2d(img.to(d), offset.to(d), weight.to(d), padding=1, mask=mask.to(d))
1182
            out.mean().backward()
1183
            if true_cpu_grads is None:
1184
                true_cpu_grads = init_weight.grad
1185
                assert true_cpu_grads is not None
1186
            else:
1187
                assert init_weight.grad is not None
1188
                res_grads = init_weight.grad.to("cpu")
1189
                torch.testing.assert_close(true_cpu_grads, res_grads)
1190

1191
    @needs_cuda
1192
    @pytest.mark.parametrize("batch_sz", (0, 33))
1193
    @pytest.mark.parametrize("dtype", (torch.float, torch.half))
1194
    @pytest.mark.opcheck_only_one()
1195
    def test_autocast(self, batch_sz, dtype):
1196
        with torch.cuda.amp.autocast():
1197
            self.test_forward(torch.device("cuda"), contiguous=False, batch_sz=batch_sz, dtype=dtype)
1198

1199
    def test_forward_scriptability(self):
1200
        # Non-regression test for https://github.com/pytorch/vision/issues/4078
1201
        torch.jit.script(ops.DeformConv2d(in_channels=8, out_channels=8, kernel_size=3))
1202

1203

1204
optests.generate_opcheck_tests(
1205
    testcase=TestDeformConv,
1206
    namespaces=["torchvision"],
1207
    failures_dict_path=os.path.join(os.path.dirname(__file__), "optests_failures_dict.json"),
1208
    additional_decorators=[],
1209
    test_utils=OPTESTS,
1210
)
1211

1212

1213
class TestFrozenBNT:
1214
    def test_frozenbatchnorm2d_repr(self):
1215
        num_features = 32
1216
        eps = 1e-5
1217
        t = ops.misc.FrozenBatchNorm2d(num_features, eps=eps)
1218

1219
        # Check integrity of object __repr__ attribute
1220
        expected_string = f"FrozenBatchNorm2d({num_features}, eps={eps})"
1221
        assert repr(t) == expected_string
1222

1223
    @pytest.mark.parametrize("seed", range(10))
1224
    def test_frozenbatchnorm2d_eps(self, seed):
1225
        torch.random.manual_seed(seed)
1226
        sample_size = (4, 32, 28, 28)
1227
        x = torch.rand(sample_size)
1228
        state_dict = dict(
1229
            weight=torch.rand(sample_size[1]),
1230
            bias=torch.rand(sample_size[1]),
1231
            running_mean=torch.rand(sample_size[1]),
1232
            running_var=torch.rand(sample_size[1]),
1233
            num_batches_tracked=torch.tensor(100),
1234
        )
1235

1236
        # Check that default eps is equal to the one of BN
1237
        fbn = ops.misc.FrozenBatchNorm2d(sample_size[1])
1238
        fbn.load_state_dict(state_dict, strict=False)
1239
        bn = torch.nn.BatchNorm2d(sample_size[1]).eval()
1240
        bn.load_state_dict(state_dict)
1241
        # Difference is expected to fall in an acceptable range
1242
        torch.testing.assert_close(fbn(x), bn(x), rtol=1e-5, atol=1e-6)
1243

1244
        # Check computation for eps > 0
1245
        fbn = ops.misc.FrozenBatchNorm2d(sample_size[1], eps=1e-5)
1246
        fbn.load_state_dict(state_dict, strict=False)
1247
        bn = torch.nn.BatchNorm2d(sample_size[1], eps=1e-5).eval()
1248
        bn.load_state_dict(state_dict)
1249
        torch.testing.assert_close(fbn(x), bn(x), rtol=1e-5, atol=1e-6)
1250

1251

1252
class TestBoxConversionToRoi:
1253
    def _get_box_sequences():
1254
        # Define here the argument type of `boxes` supported by region pooling operations
1255
        box_tensor = torch.tensor([[0, 0, 0, 100, 100], [1, 0, 0, 100, 100]], dtype=torch.float)
1256
        box_list = [
1257
            torch.tensor([[0, 0, 100, 100]], dtype=torch.float),
1258
            torch.tensor([[0, 0, 100, 100]], dtype=torch.float),
1259
        ]
1260
        box_tuple = tuple(box_list)
1261
        return box_tensor, box_list, box_tuple
1262

1263
    @pytest.mark.parametrize("box_sequence", _get_box_sequences())
1264
    def test_check_roi_boxes_shape(self, box_sequence):
1265
        # Ensure common sequences of tensors are supported
1266
        ops._utils.check_roi_boxes_shape(box_sequence)
1267

1268
    @pytest.mark.parametrize("box_sequence", _get_box_sequences())
1269
    def test_convert_boxes_to_roi_format(self, box_sequence):
1270
        # Ensure common sequences of tensors yield the same result
1271
        ref_tensor = None
1272
        if ref_tensor is None:
1273
            ref_tensor = box_sequence
1274
        else:
1275
            assert_equal(ref_tensor, ops._utils.convert_boxes_to_roi_format(box_sequence))
1276

1277

1278
class TestBoxConvert:
    def test_bbox_same(self):
        box_tensor = torch.tensor(
            [[0, 0, 100, 100], [0, 0, 0, 0], [10, 15, 30, 35], [23, 35, 93, 95]], dtype=torch.float
        )

        exp_xyxy = torch.tensor([[0, 0, 100, 100], [0, 0, 0, 0], [10, 15, 30, 35], [23, 35, 93, 95]], dtype=torch.float)

        assert exp_xyxy.size() == torch.Size([4, 4])
        assert_equal(ops.box_convert(box_tensor, in_fmt="xyxy", out_fmt="xyxy"), exp_xyxy)
        assert_equal(ops.box_convert(box_tensor, in_fmt="xywh", out_fmt="xywh"), exp_xyxy)
        assert_equal(ops.box_convert(box_tensor, in_fmt="cxcywh", out_fmt="cxcywh"), exp_xyxy)

    def test_bbox_xyxy_xywh(self):
        # Simple test: convert boxes to xywh and back, and make sure they are the same.
        # box_tensor is in x1 y1 x2 y2 format.
        box_tensor = torch.tensor(
            [[0, 0, 100, 100], [0, 0, 0, 0], [10, 15, 30, 35], [23, 35, 93, 95]], dtype=torch.float
        )
        exp_xywh = torch.tensor([[0, 0, 100, 100], [0, 0, 0, 0], [10, 15, 20, 20], [23, 35, 70, 60]], dtype=torch.float)

        assert exp_xywh.size() == torch.Size([4, 4])
        box_xywh = ops.box_convert(box_tensor, in_fmt="xyxy", out_fmt="xywh")
        assert_equal(box_xywh, exp_xywh)

        # Reverse conversion
        box_xyxy = ops.box_convert(box_xywh, in_fmt="xywh", out_fmt="xyxy")
        assert_equal(box_xyxy, box_tensor)

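    # For reference, the conversions exercised in this class are (using the usual definitions):
    #   xyxy -> xywh:   (x1, y1, x2, y2) -> (x1, y1, x2 - x1, y2 - y1)
    #   xyxy -> cxcywh: (x1, y1, x2, y2) -> ((x1 + x2) / 2, (y1 + y2) / 2, x2 - x1, y2 - y1)
    # e.g. [10, 15, 30, 35] -> [10, 15, 20, 20] in xywh and [20, 25, 20, 20] in cxcywh,
    # matching the expected tensors used in these tests.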
    def test_bbox_xyxy_cxcywh(self):
        # Simple test: convert boxes to cxcywh and back, and make sure they are the same.
        # box_tensor is in x1 y1 x2 y2 format.
        box_tensor = torch.tensor(
            [[0, 0, 100, 100], [0, 0, 0, 0], [10, 15, 30, 35], [23, 35, 93, 95]], dtype=torch.float
        )
        exp_cxcywh = torch.tensor(
            [[50, 50, 100, 100], [0, 0, 0, 0], [20, 25, 20, 20], [58, 65, 70, 60]], dtype=torch.float
        )

        assert exp_cxcywh.size() == torch.Size([4, 4])
        box_cxcywh = ops.box_convert(box_tensor, in_fmt="xyxy", out_fmt="cxcywh")
        assert_equal(box_cxcywh, exp_cxcywh)

        # Reverse conversion
        box_xyxy = ops.box_convert(box_cxcywh, in_fmt="cxcywh", out_fmt="xyxy")
        assert_equal(box_xyxy, box_tensor)

    def test_bbox_xywh_cxcywh(self):
        box_tensor = torch.tensor(
            [[0, 0, 100, 100], [0, 0, 0, 0], [10, 15, 20, 20], [23, 35, 70, 60]], dtype=torch.float
        )

        exp_cxcywh = torch.tensor(
            [[50, 50, 100, 100], [0, 0, 0, 0], [20, 25, 20, 20], [58, 65, 70, 60]], dtype=torch.float
        )

        assert exp_cxcywh.size() == torch.Size([4, 4])
        box_cxcywh = ops.box_convert(box_tensor, in_fmt="xywh", out_fmt="cxcywh")
        assert_equal(box_cxcywh, exp_cxcywh)

        # Reverse conversion
        box_xywh = ops.box_convert(box_cxcywh, in_fmt="cxcywh", out_fmt="xywh")
        assert_equal(box_xywh, box_tensor)

    @pytest.mark.parametrize("inv_infmt", ["xwyh", "cxwyh"])
    @pytest.mark.parametrize("inv_outfmt", ["xwcx", "xhwcy"])
    def test_bbox_invalid(self, inv_infmt, inv_outfmt):
        box_tensor = torch.tensor(
            [[0, 0, 100, 100], [0, 0, 0, 0], [10, 15, 20, 20], [23, 35, 70, 60]], dtype=torch.float
        )

        with pytest.raises(ValueError):
            ops.box_convert(box_tensor, inv_infmt, inv_outfmt)

    def test_bbox_convert_jit(self):
        box_tensor = torch.tensor(
            [[0, 0, 100, 100], [0, 0, 0, 0], [10, 15, 30, 35], [23, 35, 93, 95]], dtype=torch.float
        )

        scripted_fn = torch.jit.script(ops.box_convert)

        box_xywh = ops.box_convert(box_tensor, in_fmt="xyxy", out_fmt="xywh")
        scripted_xywh = scripted_fn(box_tensor, "xyxy", "xywh")
        torch.testing.assert_close(scripted_xywh, box_xywh)

        box_cxcywh = ops.box_convert(box_tensor, in_fmt="xyxy", out_fmt="cxcywh")
        scripted_cxcywh = scripted_fn(box_tensor, "xyxy", "cxcywh")
        torch.testing.assert_close(scripted_cxcywh, box_cxcywh)


class TestBoxArea:
    def area_check(self, box, expected, atol=1e-4):
        out = ops.box_area(box)
        torch.testing.assert_close(out, expected, rtol=0.0, check_dtype=False, atol=atol)

    @pytest.mark.parametrize("dtype", [torch.int8, torch.int16, torch.int32, torch.int64])
    def test_int_boxes(self, dtype):
        box_tensor = torch.tensor([[0, 0, 100, 100], [0, 0, 0, 0]], dtype=dtype)
        expected = torch.tensor([10000, 0], dtype=torch.int32)
        self.area_check(box_tensor, expected)

    @pytest.mark.parametrize("dtype", [torch.float32, torch.float64])
    def test_float_boxes(self, dtype):
        box_tensor = torch.tensor(FLOAT_BOXES, dtype=dtype)
        expected = torch.tensor([604723.0806, 600965.4666, 592761.0085], dtype=dtype)
        self.area_check(box_tensor, expected)

    def test_float16_box(self):
        box_tensor = torch.tensor(
            [[2.825, 1.8625, 3.90, 4.85], [2.825, 4.875, 19.20, 5.10], [2.925, 1.80, 8.90, 4.90]], dtype=torch.float16
        )

        expected = torch.tensor([3.2170, 3.7108, 18.5071], dtype=torch.float16)
        self.area_check(box_tensor, expected, atol=0.01)

    def test_box_area_jit(self):
        box_tensor = torch.tensor([[0, 0, 100, 100], [0, 0, 0, 0]], dtype=torch.float)
        expected = ops.box_area(box_tensor)
        scripted_fn = torch.jit.script(ops.box_area)
        scripted_area = scripted_fn(box_tensor)
        torch.testing.assert_close(scripted_area, expected)


INT_BOXES = [[0, 0, 100, 100], [0, 0, 50, 50], [200, 200, 300, 300], [0, 0, 25, 25]]
INT_BOXES2 = [[0, 0, 100, 100], [0, 0, 50, 50], [200, 200, 300, 300]]
FLOAT_BOXES = [
    [285.3538, 185.5758, 1193.5110, 851.4551],
    [285.1472, 188.7374, 1192.4984, 851.0669],
    [279.2440, 197.9812, 1189.4746, 849.2019],
]


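# Hand-check for the integer expectations used by the IoU test classes below: the expected
# matrices have shape (len(INT_BOXES), len(INT_BOXES2)) = (4, 3), and e.g.
# IoU([0, 0, 100, 100], [0, 0, 50, 50]) = 2500 / 10000 = 0.25 while
# IoU([0, 0, 25, 25], [0, 0, 100, 100]) = 625 / 10000 = 0.0625.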
def gen_box(size, dtype=torch.float):
    xy1 = torch.rand((size, 2), dtype=dtype)
    xy2 = xy1 + torch.rand((size, 2), dtype=dtype)
    return torch.cat([xy1, xy2], dim=-1)


class TestIouBase:
    @staticmethod
    def _run_test(target_fn: Callable, actual_box1, actual_box2, dtypes, atol, expected):
        for dtype in dtypes:
            actual_box1 = torch.tensor(actual_box1, dtype=dtype)
            actual_box2 = torch.tensor(actual_box2, dtype=dtype)
            expected_box = torch.tensor(expected)
            out = target_fn(actual_box1, actual_box2)
            torch.testing.assert_close(out, expected_box, rtol=0.0, check_dtype=False, atol=atol)

    @staticmethod
    def _run_jit_test(target_fn: Callable, actual_box: List):
        box_tensor = torch.tensor(actual_box, dtype=torch.float)
        expected = target_fn(box_tensor, box_tensor)
        scripted_fn = torch.jit.script(target_fn)
        scripted_out = scripted_fn(box_tensor, box_tensor)
        torch.testing.assert_close(scripted_out, expected)

    @staticmethod
    def _cartesian_product(boxes1, boxes2, target_fn: Callable):
        N = boxes1.size(0)
        M = boxes2.size(0)
        result = torch.zeros((N, M))
        for i in range(N):
            for j in range(M):
                result[i, j] = target_fn(boxes1[i].unsqueeze(0), boxes2[j].unsqueeze(0))
        return result

    @staticmethod
    def _run_cartesian_test(target_fn: Callable):
        boxes1 = gen_box(5)
        boxes2 = gen_box(7)
        a = TestIouBase._cartesian_product(boxes1, boxes2, target_fn)
        b = target_fn(boxes1, boxes2)
        torch.testing.assert_close(a, b)


class TestBoxIou(TestIouBase):
    int_expected = [[1.0, 0.25, 0.0], [0.25, 1.0, 0.0], [0.0, 0.0, 1.0], [0.0625, 0.25, 0.0]]
    float_expected = [[1.0, 0.9933, 0.9673], [0.9933, 1.0, 0.9737], [0.9673, 0.9737, 1.0]]

    @pytest.mark.parametrize(
        "actual_box1, actual_box2, dtypes, atol, expected",
        [
            pytest.param(INT_BOXES, INT_BOXES2, [torch.int16, torch.int32, torch.int64], 1e-4, int_expected),
            pytest.param(FLOAT_BOXES, FLOAT_BOXES, [torch.float16], 0.002, float_expected),
            pytest.param(FLOAT_BOXES, FLOAT_BOXES, [torch.float32, torch.float64], 1e-3, float_expected),
        ],
    )
    def test_iou(self, actual_box1, actual_box2, dtypes, atol, expected):
        self._run_test(ops.box_iou, actual_box1, actual_box2, dtypes, atol, expected)

    def test_iou_jit(self):
        self._run_jit_test(ops.box_iou, INT_BOXES)

    def test_iou_cartesian(self):
        self._run_cartesian_test(ops.box_iou)


class TestGeneralizedBoxIou(TestIouBase):
    int_expected = [[1.0, 0.25, -0.7778], [0.25, 1.0, -0.8611], [-0.7778, -0.8611, 1.0], [0.0625, 0.25, -0.8819]]
    float_expected = [[1.0, 0.9933, 0.9673], [0.9933, 1.0, 0.9737], [0.9673, 0.9737, 1.0]]

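    # GIoU = IoU - (area(enclosing box) - area(union)) / area(enclosing box), so disjoint
    # boxes get negative values. Hand-check for int_expected[0][2]: [0, 0, 100, 100] and
    # [200, 200, 300, 300] have IoU 0, union 20000 and enclosing box [0, 0, 300, 300] of
    # area 90000, giving 0 - (90000 - 20000) / 90000 = -0.7778.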
    @pytest.mark.parametrize(
        "actual_box1, actual_box2, dtypes, atol, expected",
        [
            pytest.param(INT_BOXES, INT_BOXES2, [torch.int16, torch.int32, torch.int64], 1e-4, int_expected),
            pytest.param(FLOAT_BOXES, FLOAT_BOXES, [torch.float16], 0.002, float_expected),
            pytest.param(FLOAT_BOXES, FLOAT_BOXES, [torch.float32, torch.float64], 1e-3, float_expected),
        ],
    )
    def test_iou(self, actual_box1, actual_box2, dtypes, atol, expected):
        self._run_test(ops.generalized_box_iou, actual_box1, actual_box2, dtypes, atol, expected)

    def test_iou_jit(self):
        self._run_jit_test(ops.generalized_box_iou, INT_BOXES)

    def test_iou_cartesian(self):
        self._run_cartesian_test(ops.generalized_box_iou)


class TestDistanceBoxIoU(TestIouBase):
    int_expected = [
        [1.0000, 0.1875, -0.4444],
        [0.1875, 1.0000, -0.5625],
        [-0.4444, -0.5625, 1.0000],
        [-0.0781, 0.1875, -0.6267],
    ]
    float_expected = [[1.0, 0.9933, 0.9673], [0.9933, 1.0, 0.9737], [0.9673, 0.9737, 1.0]]

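    # DIoU = IoU - d**2 / c**2, where d is the distance between the box centers and c the
    # diagonal of the smallest enclosing box. Hand-check for int_expected[0][1]:
    # [0, 0, 100, 100] and [0, 0, 50, 50] have IoU 0.25, centers (50, 50) and (25, 25)
    # (d**2 = 1250) and enclosing box [0, 0, 100, 100] (c**2 = 20000), giving
    # 0.25 - 1250 / 20000 = 0.1875.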
    @pytest.mark.parametrize(
        "actual_box1, actual_box2, dtypes, atol, expected",
        [
            pytest.param(INT_BOXES, INT_BOXES2, [torch.int16, torch.int32, torch.int64], 1e-4, int_expected),
            pytest.param(FLOAT_BOXES, FLOAT_BOXES, [torch.float16], 0.002, float_expected),
            pytest.param(FLOAT_BOXES, FLOAT_BOXES, [torch.float32, torch.float64], 1e-3, float_expected),
        ],
    )
    def test_iou(self, actual_box1, actual_box2, dtypes, atol, expected):
        self._run_test(ops.distance_box_iou, actual_box1, actual_box2, dtypes, atol, expected)

    def test_iou_jit(self):
        self._run_jit_test(ops.distance_box_iou, INT_BOXES)

    def test_iou_cartesian(self):
        self._run_cartesian_test(ops.distance_box_iou)


class TestCompleteBoxIou(TestIouBase):
    int_expected = [
        [1.0000, 0.1875, -0.4444],
        [0.1875, 1.0000, -0.5625],
        [-0.4444, -0.5625, 1.0000],
        [-0.0781, 0.1875, -0.6267],
    ]
    float_expected = [[1.0, 0.9933, 0.9673], [0.9933, 1.0, 0.9737], [0.9673, 0.9737, 1.0]]

    @pytest.mark.parametrize(
        "actual_box1, actual_box2, dtypes, atol, expected",
        [
            pytest.param(INT_BOXES, INT_BOXES2, [torch.int16, torch.int32, torch.int64], 1e-4, int_expected),
            pytest.param(FLOAT_BOXES, FLOAT_BOXES, [torch.float16], 0.002, float_expected),
            pytest.param(FLOAT_BOXES, FLOAT_BOXES, [torch.float32, torch.float64], 1e-3, float_expected),
        ],
    )
    def test_iou(self, actual_box1, actual_box2, dtypes, atol, expected):
        self._run_test(ops.complete_box_iou, actual_box1, actual_box2, dtypes, atol, expected)

    def test_iou_jit(self):
        self._run_jit_test(ops.complete_box_iou, INT_BOXES)

    def test_iou_cartesian(self):
        self._run_cartesian_test(ops.complete_box_iou)


def get_boxes(dtype, device):
    box1 = torch.tensor([-1, -1, 1, 1], dtype=dtype, device=device)
    box2 = torch.tensor([0, 0, 1, 1], dtype=dtype, device=device)
    box3 = torch.tensor([0, 1, 1, 2], dtype=dtype, device=device)
    box4 = torch.tensor([1, 1, 2, 2], dtype=dtype, device=device)

    box1s = torch.stack([box2, box2], dim=0)
    box2s = torch.stack([box3, box4], dim=0)

    return box1, box2, box3, box4, box1s, box2s


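# Geometry of the boxes returned above, for reading the loss expectations below: box1 is a
# 2x2 box centered at the origin, box2 is the unit box covering one quadrant of box1
# (IoU 0.25 with box1), box3 shares an edge with box2 (IoU 0) and box4 touches box2 only at
# the corner (1, 1) (IoU 0).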
def assert_iou_loss(iou_fn, box1, box2, expected_loss, device, reduction="none"):
    computed_loss = iou_fn(box1, box2, reduction=reduction)
    expected_loss = torch.tensor(expected_loss, device=device)
    torch.testing.assert_close(computed_loss, expected_loss)


def assert_empty_loss(iou_fn, dtype, device):
    box1 = torch.randn([0, 4], dtype=dtype, device=device).requires_grad_()
    box2 = torch.randn([0, 4], dtype=dtype, device=device).requires_grad_()
    loss = iou_fn(box1, box2, reduction="mean")
    loss.backward()
    torch.testing.assert_close(loss, torch.tensor(0.0, device=device))
    assert box1.grad is not None, "box1.grad should not be None after backward is called"
    assert box2.grad is not None, "box2.grad should not be None after backward is called"
    loss = iou_fn(box1, box2, reduction="none")
    assert loss.numel() == 0, f"{str(iou_fn)} for two empty boxes should be empty"


class TestGeneralizedBoxIouLoss:
    # We refer to the original test: https://github.com/facebookresearch/fvcore/blob/main/tests/test_giou_loss.py
    @pytest.mark.parametrize("device", cpu_and_cuda())
    @pytest.mark.parametrize("dtype", [torch.float32, torch.half])
    def test_giou_loss(self, dtype, device):
        box1, box2, box3, box4, box1s, box2s = get_boxes(dtype, device)

        # Identical boxes should have loss of 0
        assert_iou_loss(ops.generalized_box_iou_loss, box1, box1, 0.0, device=device)

        # Quarter-size box inside the other box = IoU of 0.25
        assert_iou_loss(ops.generalized_box_iou_loss, box1, box2, 0.75, device=device)

        # Two side by side boxes, area=union
        # IoU=0 and GIoU=0 (loss 1.0)
        assert_iou_loss(ops.generalized_box_iou_loss, box2, box3, 1.0, device=device)

        # Two diagonally adjacent boxes, area=2*union
        # IoU=0 and GIoU=-0.5 (loss 1.5)
        assert_iou_loss(ops.generalized_box_iou_loss, box2, box4, 1.5, device=device)

        # Test batched loss and reductions
        assert_iou_loss(ops.generalized_box_iou_loss, box1s, box2s, 2.5, device=device, reduction="sum")
        assert_iou_loss(ops.generalized_box_iou_loss, box1s, box2s, 1.25, device=device, reduction="mean")

        # Test reduction value
        # reduction value other than ["none", "mean", "sum"] should raise a ValueError
        with pytest.raises(ValueError, match="Invalid"):
            ops.generalized_box_iou_loss(box1s, box2s, reduction="xyz")

    @pytest.mark.parametrize("device", cpu_and_cuda())
    @pytest.mark.parametrize("dtype", [torch.float32, torch.half])
    def test_empty_inputs(self, dtype, device):
        assert_empty_loss(ops.generalized_box_iou_loss, dtype, device)


class TestCompleteBoxIouLoss:
    @pytest.mark.parametrize("dtype", [torch.float32, torch.half])
    @pytest.mark.parametrize("device", cpu_and_cuda())
    def test_ciou_loss(self, dtype, device):
        box1, box2, box3, box4, box1s, box2s = get_boxes(dtype, device)

        assert_iou_loss(ops.complete_box_iou_loss, box1, box1, 0.0, device=device)
        assert_iou_loss(ops.complete_box_iou_loss, box1, box2, 0.8125, device=device)
        assert_iou_loss(ops.complete_box_iou_loss, box1, box3, 1.1923, device=device)
        assert_iou_loss(ops.complete_box_iou_loss, box1, box4, 1.2500, device=device)
        assert_iou_loss(ops.complete_box_iou_loss, box1s, box2s, 1.2250, device=device, reduction="mean")
        assert_iou_loss(ops.complete_box_iou_loss, box1s, box2s, 2.4500, device=device, reduction="sum")

        with pytest.raises(ValueError, match="Invalid"):
            ops.complete_box_iou_loss(box1s, box2s, reduction="xyz")

    @pytest.mark.parametrize("device", cpu_and_cuda())
    @pytest.mark.parametrize("dtype", [torch.float32, torch.half])
    def test_empty_inputs(self, dtype, device):
        assert_empty_loss(ops.complete_box_iou_loss, dtype, device)


class TestDistanceBoxIouLoss:
    @pytest.mark.parametrize("device", cpu_and_cuda())
    @pytest.mark.parametrize("dtype", [torch.float32, torch.half])
    def test_distance_iou_loss(self, dtype, device):
        box1, box2, box3, box4, box1s, box2s = get_boxes(dtype, device)

        assert_iou_loss(ops.distance_box_iou_loss, box1, box1, 0.0, device=device)
        assert_iou_loss(ops.distance_box_iou_loss, box1, box2, 0.8125, device=device)
        assert_iou_loss(ops.distance_box_iou_loss, box1, box3, 1.1923, device=device)
        assert_iou_loss(ops.distance_box_iou_loss, box1, box4, 1.2500, device=device)
        assert_iou_loss(ops.distance_box_iou_loss, box1s, box2s, 1.2250, device=device, reduction="mean")
        assert_iou_loss(ops.distance_box_iou_loss, box1s, box2s, 2.4500, device=device, reduction="sum")

        with pytest.raises(ValueError, match="Invalid"):
            ops.distance_box_iou_loss(box1s, box2s, reduction="xyz")

    @pytest.mark.parametrize("device", cpu_and_cuda())
    @pytest.mark.parametrize("dtype", [torch.float32, torch.half])
    def test_empty_distance_iou_inputs(self, dtype, device):
        assert_empty_loss(ops.distance_box_iou_loss, dtype, device)


class TestFocalLoss:
    def _generate_diverse_input_target_pair(self, shape=(5, 2), **kwargs):
        def logit(p):
            return torch.log(p / (1 - p))

        def generate_tensor_with_range_type(shape, range_type, **kwargs):
            if range_type != "random_binary":
                low, high = {
                    "small": (0.0, 0.2),
                    "big": (0.8, 1.0),
                    "zeros": (0.0, 0.0),
                    "ones": (1.0, 1.0),
                    "random": (0.0, 1.0),
                }[range_type]
                return torch.testing.make_tensor(shape, low=low, high=high, **kwargs)
            else:
                return torch.randint(0, 2, shape, **kwargs)

        # This function will return inputs and targets with shape: (shape[0]*9, shape[1])
        inputs = []
        targets = []
        for input_range_type, target_range_type in [
            ("small", "zeros"),
            ("small", "ones"),
            ("small", "random_binary"),
            ("big", "zeros"),
            ("big", "ones"),
            ("big", "random_binary"),
            ("random", "zeros"),
            ("random", "ones"),
            ("random", "random_binary"),
        ]:
            inputs.append(logit(generate_tensor_with_range_type(shape, input_range_type, **kwargs)))
            targets.append(generate_tensor_with_range_type(shape, target_range_type, **kwargs))

        return torch.cat(inputs), torch.cat(targets)

    @pytest.mark.parametrize("alpha", [-1.0, 0.0, 0.58, 1.0])
    @pytest.mark.parametrize("gamma", [0, 2])
    @pytest.mark.parametrize("device", cpu_and_cuda())
    @pytest.mark.parametrize("dtype", [torch.float32, torch.half])
    @pytest.mark.parametrize("seed", [0, 1])
    def test_correct_ratio(self, alpha, gamma, device, dtype, seed):
        if device == "cpu" and dtype is torch.half:
            pytest.skip("Currently torch.half is not fully supported on cpu")
        # For testing the ratio with manual calculation, we require the reduction to be "none"
        reduction = "none"
        torch.random.manual_seed(seed)
        inputs, targets = self._generate_diverse_input_target_pair(dtype=dtype, device=device)
        focal_loss = ops.sigmoid_focal_loss(inputs, targets, gamma=gamma, alpha=alpha, reduction=reduction)
        ce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction=reduction)

        assert torch.all(
            focal_loss <= ce_loss
        ), "focal loss must be less than or equal to cross entropy loss for the same input"

        loss_ratio = (focal_loss / ce_loss).squeeze()
        prob = torch.sigmoid(inputs)
        p_t = prob * targets + (1 - prob) * (1 - targets)
        correct_ratio = (1.0 - p_t) ** gamma
        if alpha >= 0:
            alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
            correct_ratio = correct_ratio * alpha_t

        tol = 1e-3 if dtype is torch.half else 1e-5
        torch.testing.assert_close(correct_ratio, loss_ratio, atol=tol, rtol=tol)

    @pytest.mark.parametrize("reduction", ["mean", "sum"])
    @pytest.mark.parametrize("device", cpu_and_cuda())
    @pytest.mark.parametrize("dtype", [torch.float32, torch.half])
    @pytest.mark.parametrize("seed", [2, 3])
    def test_equal_ce_loss(self, reduction, device, dtype, seed):
        if device == "cpu" and dtype is torch.half:
            pytest.skip("Currently torch.half is not fully supported on cpu")
        # focal loss should equal ce_loss if alpha=-1 and gamma=0
        alpha = -1
        gamma = 0
        torch.random.manual_seed(seed)
        inputs, targets = self._generate_diverse_input_target_pair(dtype=dtype, device=device)
        inputs_fl = inputs.clone().requires_grad_()
        targets_fl = targets.clone()
        inputs_ce = inputs.clone().requires_grad_()
        targets_ce = targets.clone()
        focal_loss = ops.sigmoid_focal_loss(inputs_fl, targets_fl, gamma=gamma, alpha=alpha, reduction=reduction)
        ce_loss = F.binary_cross_entropy_with_logits(inputs_ce, targets_ce, reduction=reduction)

        torch.testing.assert_close(focal_loss, ce_loss)

        focal_loss.backward()
        ce_loss.backward()
        torch.testing.assert_close(inputs_fl.grad, inputs_ce.grad)

    @pytest.mark.parametrize("alpha", [-1.0, 0.0, 0.58, 1.0])
    @pytest.mark.parametrize("gamma", [0, 2])
    @pytest.mark.parametrize("reduction", ["none", "mean", "sum"])
    @pytest.mark.parametrize("device", cpu_and_cuda())
    @pytest.mark.parametrize("dtype", [torch.float32, torch.half])
    @pytest.mark.parametrize("seed", [4, 5])
    def test_jit(self, alpha, gamma, reduction, device, dtype, seed):
        if device == "cpu" and dtype is torch.half:
            pytest.skip("Currently torch.half is not fully supported on cpu")
        script_fn = torch.jit.script(ops.sigmoid_focal_loss)
        torch.random.manual_seed(seed)
        inputs, targets = self._generate_diverse_input_target_pair(dtype=dtype, device=device)
        focal_loss = ops.sigmoid_focal_loss(inputs, targets, gamma=gamma, alpha=alpha, reduction=reduction)
        scripted_focal_loss = script_fn(inputs, targets, gamma=gamma, alpha=alpha, reduction=reduction)

        tol = 1e-3 if dtype is torch.half else 1e-5
        torch.testing.assert_close(focal_loss, scripted_focal_loss, rtol=tol, atol=tol)

    # Raise ValueError for an invalid reduction mode
    @pytest.mark.parametrize("device", cpu_and_cuda())
    @pytest.mark.parametrize("dtype", [torch.float32, torch.half])
    def test_reduction_mode(self, device, dtype, reduction="xyz"):
        if device == "cpu" and dtype is torch.half:
            pytest.skip("Currently torch.half is not fully supported on cpu")
        torch.random.manual_seed(0)
        inputs, targets = self._generate_diverse_input_target_pair(device=device, dtype=dtype)
        with pytest.raises(ValueError, match="Invalid"):
            ops.sigmoid_focal_loss(inputs, targets, 0.25, 2, reduction)


class TestMasksToBoxes:
    def test_masks_box(self):
        def masks_box_check(masks, expected, atol=1e-4):
            out = ops.masks_to_boxes(masks)
            assert out.dtype == torch.float
            torch.testing.assert_close(out, expected, rtol=0.0, check_dtype=True, atol=atol)

        # Check for int type boxes.
        def _get_image():
            assets_directory = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets")
            mask_path = os.path.join(assets_directory, "masks.tiff")
            image = Image.open(mask_path)
            return image

        def _create_masks(image, masks):
            for index in range(image.n_frames):
                image.seek(index)
                frame = np.array(image)
                masks[index] = torch.tensor(frame)

            return masks

        expected = torch.tensor(
            [
                [127, 2, 165, 40],
                [2, 50, 44, 92],
                [56, 63, 98, 100],
                [139, 68, 175, 104],
                [160, 112, 198, 145],
                [49, 138, 99, 182],
                [108, 148, 152, 213],
            ],
            dtype=torch.float,
        )

        image = _get_image()
        for dtype in [torch.float16, torch.float32, torch.float64]:
            masks = torch.zeros((image.n_frames, image.height, image.width), dtype=dtype)
            masks = _create_masks(image, masks)
            masks_box_check(masks, expected)


class TestStochasticDepth:
    @pytest.mark.parametrize("seed", range(10))
    @pytest.mark.parametrize("p", [0.2, 0.5, 0.8])
    @pytest.mark.parametrize("mode", ["batch", "row"])
    def test_stochastic_depth_random(self, seed, mode, p):
        torch.manual_seed(seed)
        stats = pytest.importorskip("scipy.stats")
        batch_size = 5
        x = torch.ones(size=(batch_size, 3, 4, 4))
        layer = ops.StochasticDepth(p=p, mode=mode)
        layer.__repr__()

        trials = 250
        num_samples = 0
        counts = 0
        for _ in range(trials):
            out = layer(x)
            non_zero_count = out.sum(dim=(1, 2, 3)).nonzero().size(0)
            if mode == "batch":
                if non_zero_count == 0:
                    counts += 1
                num_samples += 1
            elif mode == "row":
                counts += batch_size - non_zero_count
                num_samples += batch_size

        p_value = stats.binomtest(counts, num_samples, p=p).pvalue
        assert p_value > 0.01

    @pytest.mark.parametrize("seed", range(10))
    @pytest.mark.parametrize("p", (0, 1))
    @pytest.mark.parametrize("mode", ["batch", "row"])
    def test_stochastic_depth(self, seed, mode, p):
        torch.manual_seed(seed)
        batch_size = 5
        x = torch.ones(size=(batch_size, 3, 4, 4))
        layer = ops.StochasticDepth(p=p, mode=mode)

        out = layer(x)
        if p == 0:
            assert out.equal(x)
        elif p == 1:
            assert out.equal(torch.zeros_like(x))

    def make_obj(self, p, mode, wrap=False):
        obj = ops.StochasticDepth(p, mode)
        return StochasticDepthWrapper(obj) if wrap else obj

    @pytest.mark.parametrize("p", (0, 1))
    @pytest.mark.parametrize("mode", ["batch", "row"])
    def test_is_leaf_node(self, p, mode):
        op_obj = self.make_obj(p, mode, wrap=True)
        graph_node_names = get_graph_node_names(op_obj)

        assert len(graph_node_names) == 2
        assert len(graph_node_names[0]) == len(graph_node_names[1])
        assert len(graph_node_names[0]) == 1 + op_obj.n_inputs


class TestUtils:
    @pytest.mark.parametrize("norm_layer", [None, nn.BatchNorm2d, nn.LayerNorm])
    def test_split_normalization_params(self, norm_layer):
        model = models.mobilenet_v3_large(norm_layer=norm_layer)
        params = ops._utils.split_normalization_params(model, None if norm_layer is None else [norm_layer])

        assert len(params[0]) == 92
        assert len(params[1]) == 82


class TestDropBlock:
    @pytest.mark.parametrize("seed", range(10))
    @pytest.mark.parametrize("dim", [2, 3])
    @pytest.mark.parametrize("p", [0, 0.5])
    @pytest.mark.parametrize("block_size", [5, 11])
    @pytest.mark.parametrize("inplace", [True, False])
    def test_drop_block(self, seed, dim, p, block_size, inplace):
        torch.manual_seed(seed)
        batch_size = 5
        channels = 3
        height = 11
        width = height
        depth = height
        if dim == 2:
            x = torch.ones(size=(batch_size, channels, height, width))
            layer = ops.DropBlock2d(p=p, block_size=block_size, inplace=inplace)
            feature_size = height * width
        elif dim == 3:
            x = torch.ones(size=(batch_size, channels, depth, height, width))
            layer = ops.DropBlock3d(p=p, block_size=block_size, inplace=inplace)
            feature_size = depth * height * width
        layer.__repr__()

        out = layer(x)
        if p == 0:
            assert out.equal(x)
        if block_size == height:
            for b, c in product(range(batch_size), range(channels)):
                assert out[b, c].count_nonzero() in (0, feature_size)

    @pytest.mark.parametrize("seed", range(10))
    @pytest.mark.parametrize("dim", [2, 3])
    @pytest.mark.parametrize("p", [0.1, 0.2])
    @pytest.mark.parametrize("block_size", [3])
    @pytest.mark.parametrize("inplace", [False])
    def test_drop_block_random(self, seed, dim, p, block_size, inplace):
        torch.manual_seed(seed)
        batch_size = 5
        channels = 3
        height = 11
        width = height
        depth = height
        if dim == 2:
            x = torch.ones(size=(batch_size, channels, height, width))
            layer = ops.DropBlock2d(p=p, block_size=block_size, inplace=inplace)
        elif dim == 3:
            x = torch.ones(size=(batch_size, channels, depth, height, width))
            layer = ops.DropBlock3d(p=p, block_size=block_size, inplace=inplace)

        trials = 250
        num_samples = 0
        counts = 0
        cell_numel = torch.tensor(x.shape).prod()
        for _ in range(trials):
            with torch.no_grad():
                out = layer(x)
            non_zero_count = out.nonzero().size(0)
            counts += cell_numel - non_zero_count
            num_samples += cell_numel

        assert abs(p - counts / num_samples) / p < 0.15

    def make_obj(self, dim, p, block_size, inplace, wrap=False):
        if dim == 2:
            obj = ops.DropBlock2d(p, block_size, inplace)
        elif dim == 3:
            obj = ops.DropBlock3d(p, block_size, inplace)
        return DropBlockWrapper(obj) if wrap else obj

    @pytest.mark.parametrize("dim", (2, 3))
    @pytest.mark.parametrize("p", [0, 1])
    @pytest.mark.parametrize("block_size", [5, 7])
    @pytest.mark.parametrize("inplace", [True, False])
    def test_is_leaf_node(self, dim, p, block_size, inplace):
        op_obj = self.make_obj(dim, p, block_size, inplace, wrap=True)
        graph_node_names = get_graph_node_names(op_obj)

        assert len(graph_node_names) == 2
        assert len(graph_node_names[0]) == len(graph_node_names[1])
        assert len(graph_node_names[0]) == 1 + op_obj.n_inputs


if __name__ == "__main__":
    pytest.main([__file__])
