3
from abc import ABC, abstractmethod
4
from functools import lru_cache
5
from itertools import product
6
from typing import Callable, List, Tuple
12
import torch.nn.functional as F
13
import torch.testing._internal.optests as optests
14
from common_utils import assert_equal, cpu_and_cuda, cpu_and_cuda_and_mps, needs_cuda, needs_mps
16
from torch import nn, Tensor
17
from torch._dynamo.utils import is_compile_supported
18
from torch.autograd import gradcheck
19
from torch.nn.modules.utils import _pair
20
from torchvision import models, ops
21
from torchvision.models.feature_extraction import get_graph_node_names
26
"test_autograd_registration",
28
"test_aot_dispatch_dynamic",
34
class DeterministicGuard:
35
def __init__(self, deterministic, *, warn_only=False):
36
self.deterministic = deterministic
37
self.warn_only = warn_only
40
self.deterministic_restore = torch.are_deterministic_algorithms_enabled()
41
self.warn_only_restore = torch.is_deterministic_algorithms_warn_only_enabled()
42
torch.use_deterministic_algorithms(self.deterministic, warn_only=self.warn_only)
44
def __exit__(self, exception_type, exception_value, traceback):
45
torch.use_deterministic_algorithms(self.deterministic_restore, warn_only=self.warn_only_restore)
48
class RoIOpTesterModuleWrapper(nn.Module):
49
def __init__(self, obj):
54
def forward(self, a, b):
58
class MultiScaleRoIAlignModuleWrapper(nn.Module):
59
def __init__(self, obj):
64
def forward(self, a, b, c):
68
class DeformConvModuleWrapper(nn.Module):
69
def __init__(self, obj):
74
def forward(self, a, b, c):
78
class StochasticDepthWrapper(nn.Module):
79
def __init__(self, obj):
88
class DropBlockWrapper(nn.Module):
89
def __init__(self, obj):
98
class PoolWrapper(nn.Module):
99
def __init__(self, pool: nn.Module):
103
def forward(self, imgs: Tensor, boxes: List[Tensor]) -> Tensor:
104
return self.pool(imgs, boxes)
107
class RoIOpTester(ABC):
108
dtype = torch.float64
109
mps_dtype = torch.float32
110
mps_backward_atol = 2e-2
112
@pytest.mark.parametrize("device", cpu_and_cuda_and_mps())
113
@pytest.mark.parametrize("contiguous", (True, False))
114
@pytest.mark.parametrize(
123
def test_forward(self, device, contiguous, x_dtype, rois_dtype=None, deterministic=False, **kwargs):
124
if device == "mps" and x_dtype is torch.float64:
125
pytest.skip("MPS does not support float64")
127
rois_dtype = x_dtype if rois_dtype is None else rois_dtype
130
if x_dtype is torch.half:
135
elif x_dtype == torch.bfloat16:
140
n_channels = 2 * (pool_size**2)
141
x = torch.rand(2, n_channels, 10, 10, dtype=x_dtype, device=device)
143
x = x.permute(0, 1, 3, 2)
145
[[0, 0, 0, 9, 9], [0, 0, 5, 4, 9], [0, 5, 5, 9, 9], [1, 0, 0, 9, 9]],
150
pool_h, pool_w = pool_size, pool_size
151
with DeterministicGuard(deterministic):
152
y = self.fn(x, rois, pool_h, pool_w, spatial_scale=1, sampling_ratio=-1, **kwargs)
154
assert y.dtype == x.dtype
155
gt_y = self.expected_fn(
156
x, rois, pool_h, pool_w, spatial_scale=1, sampling_ratio=-1, device=device, dtype=x_dtype, **kwargs
159
torch.testing.assert_close(gt_y.to(y), y, rtol=tol, atol=tol)
161
@pytest.mark.parametrize("device", cpu_and_cuda())
162
def test_is_leaf_node(self, device):
163
op_obj = self.make_obj(wrap=True).to(device=device)
164
graph_node_names = get_graph_node_names(op_obj)
166
assert len(graph_node_names) == 2
167
assert len(graph_node_names[0]) == len(graph_node_names[1])
168
assert len(graph_node_names[0]) == 1 + op_obj.n_inputs
170
@pytest.mark.parametrize("device", cpu_and_cuda())
171
def test_torch_fx_trace(self, device, x_dtype=torch.float, rois_dtype=torch.float):
172
op_obj = self.make_obj().to(device=device)
173
graph_module = torch.fx.symbolic_trace(op_obj)
175
n_channels = 2 * (pool_size**2)
176
x = torch.rand(2, n_channels, 5, 5, dtype=x_dtype, device=device)
178
[[0, 0, 0, 9, 9], [0, 0, 5, 4, 9], [0, 5, 5, 9, 9], [1, 0, 0, 9, 9]],
182
output_gt = op_obj(x, rois)
183
assert output_gt.dtype == x.dtype
184
output_fx = graph_module(x, rois)
185
assert output_fx.dtype == x.dtype
187
torch.testing.assert_close(output_gt, output_fx, rtol=tol, atol=tol)
189
@pytest.mark.parametrize("seed", range(10))
190
@pytest.mark.parametrize("device", cpu_and_cuda_and_mps())
191
@pytest.mark.parametrize("contiguous", (True, False))
192
def test_backward(self, seed, device, contiguous, deterministic=False):
193
atol = self.mps_backward_atol if device == "mps" else 1e-05
194
dtype = self.mps_dtype if device == "mps" else self.dtype
196
torch.random.manual_seed(seed)
198
x = torch.rand(1, 2 * (pool_size**2), 5, 5, dtype=dtype, device=device, requires_grad=True)
200
x = x.permute(0, 1, 3, 2)
202
[[0, 0, 0, 4, 4], [0, 0, 2, 3, 4], [0, 2, 2, 4, 4]], dtype=dtype, device=device
206
return self.fn(z, rois, pool_size, pool_size, spatial_scale=1, sampling_ratio=1)
208
script_func = self.get_script_fn(rois, pool_size)
210
with DeterministicGuard(deterministic):
211
gradcheck(func, (x,), atol=atol)
213
gradcheck(script_func, (x,), atol=atol)
216
def test_mps_error_inputs(self):
218
x = torch.rand(1, 2 * (pool_size**2), 5, 5, dtype=torch.float16, device="mps", requires_grad=True)
220
[[0, 0, 0, 4, 4], [0, 0, 2, 3, 4], [0, 2, 2, 4, 4]], dtype=torch.float16, device="mps"
224
return self.fn(z, rois, pool_size, pool_size, spatial_scale=1, sampling_ratio=1)
227
RuntimeError, match="MPS does not support (?:ps_)?roi_(?:align|pool)? backward with float16 inputs."
229
gradcheck(func, (x,))
232
@pytest.mark.parametrize("x_dtype", (torch.float, torch.half))
233
@pytest.mark.parametrize("rois_dtype", (torch.float, torch.half))
234
def test_autocast(self, x_dtype, rois_dtype):
235
with torch.cuda.amp.autocast():
236
self.test_forward(torch.device("cuda"), contiguous=False, x_dtype=x_dtype, rois_dtype=rois_dtype)
238
def _helper_boxes_shape(self, func):
240
with pytest.raises(AssertionError):
241
a = torch.linspace(1, 8 * 8, 8 * 8).reshape(1, 1, 8, 8)
242
boxes = torch.tensor([[0, 0, 3, 3]], dtype=a.dtype)
243
func(a, boxes, output_size=(2, 2))
246
with pytest.raises(AssertionError):
247
a = torch.linspace(1, 8 * 8, 8 * 8).reshape(1, 1, 8, 8)
248
boxes = torch.tensor([[0, 0, 3]], dtype=a.dtype)
249
ops.roi_pool(a, [boxes], output_size=(2, 2))
251
def _helper_jit_boxes_list(self, model):
252
x = torch.rand(2, 1, 10, 10)
253
roi = torch.tensor([[0, 0, 0, 9, 9], [0, 0, 5, 4, 9], [0, 5, 5, 9, 9], [1, 0, 0, 9, 9]], dtype=torch.float).t()
255
scriped = torch.jit.script(model)
257
assert y.shape == (10, 1, 3, 3)
260
def fn(*args, **kwargs):
264
def make_obj(*args, **kwargs):
268
def get_script_fn(*args, **kwargs):
272
def expected_fn(*args, **kwargs):
276
class TestRoiPool(RoIOpTester):
277
def fn(self, x, rois, pool_h, pool_w, spatial_scale=1, sampling_ratio=-1, **kwargs):
278
return ops.RoIPool((pool_h, pool_w), spatial_scale)(x, rois)
280
def make_obj(self, pool_h=5, pool_w=5, spatial_scale=1, wrap=False):
281
obj = ops.RoIPool((pool_h, pool_w), spatial_scale)
282
return RoIOpTesterModuleWrapper(obj) if wrap else obj
284
def get_script_fn(self, rois, pool_size):
285
scriped = torch.jit.script(ops.roi_pool)
286
return lambda x: scriped(x, rois, pool_size)
289
self, x, rois, pool_h, pool_w, spatial_scale=1, sampling_ratio=-1, device=None, dtype=torch.float64
292
device = torch.device("cpu")
294
n_channels = x.size(1)
295
y = torch.zeros(rois.size(0), n_channels, pool_h, pool_w, dtype=dtype, device=device)
297
def get_slice(k, block):
298
return slice(int(np.floor(k * block)), int(np.ceil((k + 1) * block)))
300
for roi_idx, roi in enumerate(rois):
301
batch_idx = int(roi[0])
302
j_begin, i_begin, j_end, i_end = (int(round(x.item() * spatial_scale)) for x in roi[1:])
303
roi_x = x[batch_idx, :, i_begin : i_end + 1, j_begin : j_end + 1]
305
roi_h, roi_w = roi_x.shape[-2:]
306
bin_h = roi_h / pool_h
307
bin_w = roi_w / pool_w
309
for i in range(0, pool_h):
310
for j in range(0, pool_w):
311
bin_x = roi_x[:, get_slice(i, bin_h), get_slice(j, bin_w)]
312
if bin_x.numel() > 0:
313
y[roi_idx, :, i, j] = bin_x.reshape(n_channels, -1).max(dim=1)[0]
316
def test_boxes_shape(self):
317
self._helper_boxes_shape(ops.roi_pool)
319
def test_jit_boxes_list(self):
320
model = PoolWrapper(ops.RoIPool(output_size=[3, 3], spatial_scale=1.0))
321
self._helper_jit_boxes_list(model)
324
class TestPSRoIPool(RoIOpTester):
325
mps_backward_atol = 5e-2
327
def fn(self, x, rois, pool_h, pool_w, spatial_scale=1, sampling_ratio=-1, **kwargs):
328
return ops.PSRoIPool((pool_h, pool_w), 1)(x, rois)
330
def make_obj(self, pool_h=5, pool_w=5, spatial_scale=1, wrap=False):
331
obj = ops.PSRoIPool((pool_h, pool_w), spatial_scale)
332
return RoIOpTesterModuleWrapper(obj) if wrap else obj
334
def get_script_fn(self, rois, pool_size):
335
scriped = torch.jit.script(ops.ps_roi_pool)
336
return lambda x: scriped(x, rois, pool_size)
339
self, x, rois, pool_h, pool_w, spatial_scale=1, sampling_ratio=-1, device=None, dtype=torch.float64
342
device = torch.device("cpu")
343
n_input_channels = x.size(1)
344
assert n_input_channels % (pool_h * pool_w) == 0, "input channels must be divisible by ph * pw"
345
n_output_channels = int(n_input_channels / (pool_h * pool_w))
346
y = torch.zeros(rois.size(0), n_output_channels, pool_h, pool_w, dtype=dtype, device=device)
348
def get_slice(k, block):
349
return slice(int(np.floor(k * block)), int(np.ceil((k + 1) * block)))
351
for roi_idx, roi in enumerate(rois):
352
batch_idx = int(roi[0])
353
j_begin, i_begin, j_end, i_end = (int(round(x.item() * spatial_scale)) for x in roi[1:])
354
roi_x = x[batch_idx, :, i_begin : i_end + 1, j_begin : j_end + 1]
356
roi_height = max(i_end - i_begin, 1)
357
roi_width = max(j_end - j_begin, 1)
358
bin_h, bin_w = roi_height / float(pool_h), roi_width / float(pool_w)
360
for i in range(0, pool_h):
361
for j in range(0, pool_w):
362
bin_x = roi_x[:, get_slice(i, bin_h), get_slice(j, bin_w)]
363
if bin_x.numel() > 0:
364
area = bin_x.size(-2) * bin_x.size(-1)
365
for c_out in range(0, n_output_channels):
366
c_in = c_out * (pool_h * pool_w) + pool_w * i + j
367
t = torch.sum(bin_x[c_in, :, :])
368
y[roi_idx, c_out, i, j] = t / area
371
def test_boxes_shape(self):
372
self._helper_boxes_shape(ops.ps_roi_pool)
375
def bilinear_interpolate(data, y, x, snap_border=False):
376
height, width = data.shape
381
elif height - 1 <= y < height:
386
elif width - 1 <= x < width:
389
y_low = int(math.floor(y))
390
x_low = int(math.floor(x))
400
for wx, xp in zip((wx_l, wx_h), (x_low, x_high)):
401
for wy, yp in zip((wy_l, wy_h), (y_low, y_high)):
402
if 0 <= yp < height and 0 <= xp < width:
403
val += wx * wy * data[yp, xp]
407
class TestRoIAlign(RoIOpTester):
408
mps_backward_atol = 6e-2
410
def fn(self, x, rois, pool_h, pool_w, spatial_scale=1, sampling_ratio=-1, aligned=False, **kwargs):
412
(pool_h, pool_w), spatial_scale=spatial_scale, sampling_ratio=sampling_ratio, aligned=aligned
415
def make_obj(self, pool_h=5, pool_w=5, spatial_scale=1, sampling_ratio=-1, aligned=False, wrap=False):
417
(pool_h, pool_w), spatial_scale=spatial_scale, sampling_ratio=sampling_ratio, aligned=aligned
419
return RoIOpTesterModuleWrapper(obj) if wrap else obj
421
def get_script_fn(self, rois, pool_size):
422
scriped = torch.jit.script(ops.roi_align)
423
return lambda x: scriped(x, rois, pool_size)
438
device = torch.device("cpu")
439
n_channels = in_data.size(1)
440
out_data = torch.zeros(rois.size(0), n_channels, pool_h, pool_w, dtype=dtype, device=device)
442
offset = 0.5 if aligned else 0.0
444
for r, roi in enumerate(rois):
445
batch_idx = int(roi[0])
446
j_begin, i_begin, j_end, i_end = (x.item() * spatial_scale - offset for x in roi[1:])
448
roi_h = i_end - i_begin
449
roi_w = j_end - j_begin
450
bin_h = roi_h / pool_h
451
bin_w = roi_w / pool_w
453
for i in range(0, pool_h):
454
start_h = i_begin + i * bin_h
455
grid_h = sampling_ratio if sampling_ratio > 0 else int(np.ceil(bin_h))
456
for j in range(0, pool_w):
457
start_w = j_begin + j * bin_w
458
grid_w = sampling_ratio if sampling_ratio > 0 else int(np.ceil(bin_w))
460
for channel in range(0, n_channels):
462
for iy in range(0, grid_h):
463
y = start_h + (iy + 0.5) * bin_h / grid_h
464
for ix in range(0, grid_w):
465
x = start_w + (ix + 0.5) * bin_w / grid_w
466
val += bilinear_interpolate(in_data[batch_idx, channel, :, :], y, x, snap_border=True)
467
val /= grid_h * grid_w
469
out_data[r, channel, i, j] = val
472
def test_boxes_shape(self):
473
self._helper_boxes_shape(ops.roi_align)
475
@pytest.mark.parametrize("aligned", (True, False))
476
@pytest.mark.parametrize("device", cpu_and_cuda_and_mps())
477
@pytest.mark.parametrize("x_dtype", (torch.float16, torch.float32, torch.float64))
478
@pytest.mark.parametrize("contiguous", (True, False))
479
@pytest.mark.parametrize("deterministic", (True, False))
480
@pytest.mark.opcheck_only_one()
481
def test_forward(self, device, contiguous, deterministic, aligned, x_dtype, rois_dtype=None):
482
if deterministic and device == "cpu":
483
pytest.skip("cpu is always deterministic, don't retest")
484
super().test_forward(
486
contiguous=contiguous,
487
deterministic=deterministic,
489
rois_dtype=rois_dtype,
494
@pytest.mark.parametrize("aligned", (True, False))
495
@pytest.mark.parametrize("deterministic", (True, False))
496
@pytest.mark.parametrize("x_dtype", (torch.float, torch.half))
497
@pytest.mark.parametrize("rois_dtype", (torch.float, torch.half))
498
@pytest.mark.opcheck_only_one()
499
def test_autocast(self, aligned, deterministic, x_dtype, rois_dtype):
500
with torch.cuda.amp.autocast():
502
torch.device("cuda"),
504
deterministic=deterministic,
507
rois_dtype=rois_dtype,
510
@pytest.mark.skip(reason="1/5000 flaky failure")
511
@pytest.mark.parametrize("aligned", (True, False))
512
@pytest.mark.parametrize("deterministic", (True, False))
513
@pytest.mark.parametrize("x_dtype", (torch.float, torch.bfloat16))
514
@pytest.mark.parametrize("rois_dtype", (torch.float, torch.bfloat16))
515
def test_autocast_cpu(self, aligned, deterministic, x_dtype, rois_dtype):
516
with torch.cpu.amp.autocast():
520
deterministic=deterministic,
523
rois_dtype=rois_dtype,
526
@pytest.mark.parametrize("seed", range(10))
527
@pytest.mark.parametrize("device", cpu_and_cuda_and_mps())
528
@pytest.mark.parametrize("contiguous", (True, False))
529
@pytest.mark.parametrize("deterministic", (True, False))
530
@pytest.mark.opcheck_only_one()
531
def test_backward(self, seed, device, contiguous, deterministic):
532
if deterministic and device == "cpu":
533
pytest.skip("cpu is always deterministic, don't retest")
534
if deterministic and device == "mps":
535
pytest.skip("no deterministic implementation for mps")
536
if deterministic and not is_compile_supported(device):
537
pytest.skip("deterministic implementation only if torch.compile supported")
538
super().test_backward(seed, device, contiguous, deterministic)
540
def _make_rois(self, img_size, num_imgs, dtype, num_rois=1000):
541
rois = torch.randint(0, img_size // 2, size=(num_rois, 5)).to(dtype)
542
rois[:, 0] = torch.randint(0, num_imgs, size=(num_rois,))
543
rois[:, 3:] += rois[:, 1:3]
546
@pytest.mark.parametrize("aligned", (True, False))
547
@pytest.mark.parametrize("scale, zero_point", ((1, 0), (2, 10), (0.1, 50)))
548
@pytest.mark.parametrize("qdtype", (torch.qint8, torch.quint8, torch.qint32))
549
@pytest.mark.opcheck_only_one()
550
def test_qroialign(self, aligned, scale, zero_point, qdtype):
551
"""Make sure quantized version of RoIAlign is close to float version"""
558
x = torch.randint(50, 100, size=(num_imgs, n_channels, img_size, img_size)).to(dtype)
559
qx = torch.quantize_per_tensor(x, scale=scale, zero_point=zero_point, dtype=qdtype)
561
rois = self._make_rois(img_size, num_imgs, dtype)
562
qrois = torch.quantize_per_tensor(rois, scale=scale, zero_point=zero_point, dtype=qdtype)
564
x, rois = qx.dequantize(), qrois.dequantize()
569
output_size=pool_size,
577
output_size=pool_size,
585
quantized_float_y = torch.quantize_per_tensor(y, scale=scale, zero_point=zero_point, dtype=qdtype)
589
assert (qy == quantized_float_y).all()
590
except AssertionError:
597
diff_idx = torch.where(qy != quantized_float_y)
598
num_diff = diff_idx[0].numel()
599
assert num_diff / qy.numel() < 0.05
601
abs_diff = torch.abs(qy[diff_idx].dequantize() - quantized_float_y[diff_idx].dequantize())
602
t_scale = torch.full_like(abs_diff, fill_value=scale)
603
torch.testing.assert_close(abs_diff, t_scale, rtol=1e-5, atol=1e-5)
605
def test_qroi_align_multiple_images(self):
607
x = torch.randint(50, 100, size=(2, 3, 10, 10)).to(dtype)
608
qx = torch.quantize_per_tensor(x, scale=1, zero_point=0, dtype=torch.qint8)
609
rois = self._make_rois(img_size=10, num_imgs=2, dtype=dtype, num_rois=10)
610
qrois = torch.quantize_per_tensor(rois, scale=1, zero_point=0, dtype=torch.qint8)
611
with pytest.raises(RuntimeError, match="Only one image per batch is allowed"):
612
ops.roi_align(qx, qrois, output_size=5)
614
def test_jit_boxes_list(self):
615
model = PoolWrapper(ops.RoIAlign(output_size=[3, 3], spatial_scale=1.0, sampling_ratio=-1))
616
self._helper_jit_boxes_list(model)
619
class TestPSRoIAlign(RoIOpTester):
620
mps_backward_atol = 5e-2
622
def fn(self, x, rois, pool_h, pool_w, spatial_scale=1, sampling_ratio=-1, **kwargs):
623
return ops.PSRoIAlign((pool_h, pool_w), spatial_scale=spatial_scale, sampling_ratio=sampling_ratio)(x, rois)
625
def make_obj(self, pool_h=5, pool_w=5, spatial_scale=1, sampling_ratio=-1, wrap=False):
626
obj = ops.PSRoIAlign((pool_h, pool_w), spatial_scale=spatial_scale, sampling_ratio=sampling_ratio)
627
return RoIOpTesterModuleWrapper(obj) if wrap else obj
629
def get_script_fn(self, rois, pool_size):
630
scriped = torch.jit.script(ops.ps_roi_align)
631
return lambda x: scriped(x, rois, pool_size)
634
self, in_data, rois, pool_h, pool_w, device, spatial_scale=1, sampling_ratio=-1, dtype=torch.float64
637
device = torch.device("cpu")
638
n_input_channels = in_data.size(1)
639
assert n_input_channels % (pool_h * pool_w) == 0, "input channels must be divisible by ph * pw"
640
n_output_channels = int(n_input_channels / (pool_h * pool_w))
641
out_data = torch.zeros(rois.size(0), n_output_channels, pool_h, pool_w, dtype=dtype, device=device)
643
for r, roi in enumerate(rois):
644
batch_idx = int(roi[0])
645
j_begin, i_begin, j_end, i_end = (x.item() * spatial_scale - 0.5 for x in roi[1:])
647
roi_h = i_end - i_begin
648
roi_w = j_end - j_begin
649
bin_h = roi_h / pool_h
650
bin_w = roi_w / pool_w
652
for i in range(0, pool_h):
653
start_h = i_begin + i * bin_h
654
grid_h = sampling_ratio if sampling_ratio > 0 else int(np.ceil(bin_h))
655
for j in range(0, pool_w):
656
start_w = j_begin + j * bin_w
657
grid_w = sampling_ratio if sampling_ratio > 0 else int(np.ceil(bin_w))
658
for c_out in range(0, n_output_channels):
659
c_in = c_out * (pool_h * pool_w) + pool_w * i + j
662
for iy in range(0, grid_h):
663
y = start_h + (iy + 0.5) * bin_h / grid_h
664
for ix in range(0, grid_w):
665
x = start_w + (ix + 0.5) * bin_w / grid_w
666
val += bilinear_interpolate(in_data[batch_idx, c_in, :, :], y, x, snap_border=True)
667
val /= grid_h * grid_w
669
out_data[r, c_out, i, j] = val
672
def test_boxes_shape(self):
673
self._helper_boxes_shape(ops.ps_roi_align)
676
@pytest.mark.parametrize(
679
torch.ops.torchvision.roi_pool,
680
torch.ops.torchvision.ps_roi_pool,
681
torch.ops.torchvision.roi_align,
682
torch.ops.torchvision.ps_roi_align,
685
@pytest.mark.parametrize("dtype", (torch.float16, torch.float32, torch.float64))
686
@pytest.mark.parametrize("device", cpu_and_cuda())
687
@pytest.mark.parametrize("requires_grad", (True, False))
688
def test_roi_opcheck(op, dtype, device, requires_grad):
695
[[0, 0, 0, 9, 9], [0, 0, 5, 4, 9], [0, 5, 5, 9, 9], [1, 0, 0, 9, 9]],
698
requires_grad=requires_grad,
701
num_channels = 2 * (pool_size**2)
702
x = torch.rand(2, num_channels, 10, 10, dtype=dtype, device=device)
704
kwargs = dict(rois=rois, spatial_scale=1, pooled_height=pool_size, pooled_width=pool_size)
705
if op in (torch.ops.torchvision.roi_align, torch.ops.torchvision.ps_roi_align):
706
kwargs["sampling_ratio"] = -1
707
if op is torch.ops.torchvision.roi_align:
708
kwargs["aligned"] = True
710
optests.opcheck(op, args=(x,), kwargs=kwargs)
713
class TestMultiScaleRoIAlign:
714
def make_obj(self, fmap_names=None, output_size=(7, 7), sampling_ratio=2, wrap=False):
715
if fmap_names is None:
717
obj = ops.poolers.MultiScaleRoIAlign(fmap_names, output_size, sampling_ratio)
718
return MultiScaleRoIAlignModuleWrapper(obj) if wrap else obj
720
def test_msroialign_repr(self):
725
t = self.make_obj(fmap_names, output_size, sampling_ratio, wrap=False)
729
f"MultiScaleRoIAlign(featmap_names={fmap_names}, output_size={output_size}, "
730
f"sampling_ratio={sampling_ratio})"
732
assert repr(t) == expected_string
734
@pytest.mark.parametrize("device", cpu_and_cuda())
735
def test_is_leaf_node(self, device):
736
op_obj = self.make_obj(wrap=True).to(device=device)
737
graph_node_names = get_graph_node_names(op_obj)
739
assert len(graph_node_names) == 2
740
assert len(graph_node_names[0]) == len(graph_node_names[1])
741
assert len(graph_node_names[0]) == 1 + op_obj.n_inputs
745
def _reference_nms(self, boxes, scores, iou_threshold):
748
boxes: boxes in corner-form
749
scores: probabilities
750
iou_threshold: intersection over union threshold
752
picked: a list of indexes of the kept boxes
755
_, indexes = scores.sort(descending=True)
756
while len(indexes) > 0:
758
picked.append(current.item())
759
if len(indexes) == 1:
761
current_box = boxes[current, :]
762
indexes = indexes[1:]
763
rest_boxes = boxes[indexes, :]
764
iou = ops.box_iou(rest_boxes, current_box.unsqueeze(0)).squeeze(1)
765
indexes = indexes[iou <= iou_threshold]
767
return torch.as_tensor(picked)
769
def _create_tensors_with_iou(self, N, iou_thresh):
777
boxes = torch.rand(N, 4) * 100
778
boxes[:, 2:] += boxes[:, :2]
779
boxes[-1, :] = boxes[0, :]
780
x0, y0, x1, y1 = boxes[-1].tolist()
782
boxes[-1, 2] += (x1 - x0) * (1 - iou_thresh) / iou_thresh
783
scores = torch.rand(N)
786
@pytest.mark.parametrize("iou", (0.2, 0.5, 0.8))
787
@pytest.mark.parametrize("seed", range(10))
788
@pytest.mark.opcheck_only_one()
789
def test_nms_ref(self, iou, seed):
790
torch.random.manual_seed(seed)
791
err_msg = "NMS incompatible between CPU and reference implementation for IoU={}"
792
boxes, scores = self._create_tensors_with_iou(1000, iou)
793
keep_ref = self._reference_nms(boxes, scores, iou)
794
keep = ops.nms(boxes, scores, iou)
795
torch.testing.assert_close(keep, keep_ref, msg=err_msg.format(iou))
797
def test_nms_input_errors(self):
798
with pytest.raises(RuntimeError):
799
ops.nms(torch.rand(4), torch.rand(3), 0.5)
800
with pytest.raises(RuntimeError):
801
ops.nms(torch.rand(3, 5), torch.rand(3), 0.5)
802
with pytest.raises(RuntimeError):
803
ops.nms(torch.rand(3, 4), torch.rand(3, 2), 0.5)
804
with pytest.raises(RuntimeError):
805
ops.nms(torch.rand(3, 4), torch.rand(4), 0.5)
807
@pytest.mark.parametrize("iou", (0.2, 0.5, 0.8))
808
@pytest.mark.parametrize("scale, zero_point", ((1, 0), (2, 50), (3, 10)))
809
@pytest.mark.opcheck_only_one()
810
def test_qnms(self, iou, scale, zero_point):
814
err_msg = "NMS and QNMS give different results for IoU={}"
815
boxes, scores = self._create_tensors_with_iou(1000, iou)
818
qboxes = torch.quantize_per_tensor(boxes, scale=scale, zero_point=zero_point, dtype=torch.quint8)
819
qscores = torch.quantize_per_tensor(scores, scale=scale, zero_point=zero_point, dtype=torch.quint8)
821
boxes = qboxes.dequantize()
822
scores = qscores.dequantize()
824
keep = ops.nms(boxes, scores, iou)
825
qkeep = ops.nms(qboxes, qscores, iou)
827
torch.testing.assert_close(qkeep, keep, msg=err_msg.format(iou))
829
@pytest.mark.parametrize(
832
pytest.param("cuda", marks=pytest.mark.needs_cuda),
833
pytest.param("mps", marks=pytest.mark.needs_mps),
836
@pytest.mark.parametrize("iou", (0.2, 0.5, 0.8))
837
@pytest.mark.opcheck_only_one()
838
def test_nms_gpu(self, iou, device, dtype=torch.float64):
839
dtype = torch.float32 if device == "mps" else dtype
840
tol = 1e-3 if dtype is torch.half else 1e-5
841
err_msg = "NMS incompatible between CPU and CUDA for IoU={}"
843
boxes, scores = self._create_tensors_with_iou(1000, iou)
844
r_cpu = ops.nms(boxes, scores, iou)
845
r_gpu = ops.nms(boxes.to(device), scores.to(device), iou)
847
is_eq = torch.allclose(r_cpu, r_gpu.cpu())
851
is_eq = torch.allclose(scores[r_cpu], scores[r_gpu.cpu()], rtol=tol, atol=tol)
852
assert is_eq, err_msg.format(iou)
855
@pytest.mark.parametrize("iou", (0.2, 0.5, 0.8))
856
@pytest.mark.parametrize("dtype", (torch.float, torch.half))
857
@pytest.mark.opcheck_only_one()
858
def test_autocast(self, iou, dtype):
859
with torch.cuda.amp.autocast():
860
self.test_nms_gpu(iou=iou, dtype=dtype, device="cuda")
862
@pytest.mark.parametrize("iou", (0.2, 0.5, 0.8))
863
@pytest.mark.parametrize("dtype", (torch.float, torch.bfloat16))
864
def test_autocast_cpu(self, iou, dtype):
865
boxes, scores = self._create_tensors_with_iou(1000, iou)
866
with torch.cpu.amp.autocast():
867
keep_ref_float = ops.nms(boxes.to(dtype).float(), scores.to(dtype).float(), iou)
868
keep_dtype = ops.nms(boxes.to(dtype), scores.to(dtype), iou)
869
torch.testing.assert_close(keep_ref_float, keep_dtype)
871
@pytest.mark.parametrize(
874
pytest.param("cuda", marks=pytest.mark.needs_cuda),
875
pytest.param("mps", marks=pytest.mark.needs_mps),
878
@pytest.mark.opcheck_only_one()
879
def test_nms_float16(self, device):
880
boxes = torch.tensor(
882
[285.3538, 185.5758, 1193.5110, 851.4551],
883
[285.1472, 188.7374, 1192.4984, 851.0669],
884
[279.2440, 197.9812, 1189.4746, 849.2019],
887
scores = torch.tensor([0.6370, 0.7569, 0.3966]).to(device)
890
keep32 = ops.nms(boxes, scores, iou_thres)
891
keep16 = ops.nms(boxes.to(torch.float16), scores.to(torch.float16), iou_thres)
892
assert_equal(keep32, keep16)
894
@pytest.mark.parametrize("seed", range(10))
895
@pytest.mark.opcheck_only_one()
896
def test_batched_nms_implementations(self, seed):
897
"""Make sure that both implementations of batched_nms yield identical results"""
898
torch.random.manual_seed(seed)
903
boxes = torch.cat((torch.rand(num_boxes, 2), torch.rand(num_boxes, 2) + 10), dim=1)
904
assert max(boxes[:, 0]) < min(boxes[:, 2])
905
assert max(boxes[:, 1]) < min(boxes[:, 3])
907
scores = torch.rand(num_boxes)
908
idxs = torch.randint(0, 4, size=(num_boxes,))
909
keep_vanilla = ops.boxes._batched_nms_vanilla(boxes, scores, idxs, iou_threshold)
910
keep_trick = ops.boxes._batched_nms_coordinate_trick(boxes, scores, idxs, iou_threshold)
912
torch.testing.assert_close(
913
keep_vanilla, keep_trick, msg="The vanilla and the trick implementation yield different nms outputs."
917
empty = torch.empty((0,), dtype=torch.int64)
918
torch.testing.assert_close(empty, ops.batched_nms(empty, None, None, None))
921
optests.generate_opcheck_tests(
923
namespaces=["torchvision"],
924
failures_dict_path=os.path.join(os.path.dirname(__file__), "optests_failures_dict.json"),
925
additional_decorators=[],
931
dtype = torch.float64
933
def expected_fn(self, x, weight, offset, mask, bias, stride=1, padding=0, dilation=1):
934
stride_h, stride_w = _pair(stride)
935
pad_h, pad_w = _pair(padding)
936
dil_h, dil_w = _pair(dilation)
937
weight_h, weight_w = weight.shape[-2:]
939
n_batches, n_in_channels, in_h, in_w = x.shape
940
n_out_channels = weight.shape[0]
942
out_h = (in_h + 2 * pad_h - (dil_h * (weight_h - 1) + 1)) // stride_h + 1
943
out_w = (in_w + 2 * pad_w - (dil_w * (weight_w - 1) + 1)) // stride_w + 1
945
n_offset_grps = offset.shape[1] // (2 * weight_h * weight_w)
946
in_c_per_offset_grp = n_in_channels // n_offset_grps
948
n_weight_grps = n_in_channels // weight.shape[1]
949
in_c_per_weight_grp = weight.shape[1]
950
out_c_per_weight_grp = n_out_channels // n_weight_grps
952
out = torch.zeros(n_batches, n_out_channels, out_h, out_w, device=x.device, dtype=x.dtype)
953
for b in range(n_batches):
954
for c_out in range(n_out_channels):
955
for i in range(out_h):
956
for j in range(out_w):
957
for di in range(weight_h):
958
for dj in range(weight_w):
959
for c in range(in_c_per_weight_grp):
960
weight_grp = c_out // out_c_per_weight_grp
961
c_in = weight_grp * in_c_per_weight_grp + c
963
offset_grp = c_in // in_c_per_offset_grp
964
mask_idx = offset_grp * (weight_h * weight_w) + di * weight_w + dj
965
offset_idx = 2 * mask_idx
967
pi = stride_h * i - pad_h + dil_h * di + offset[b, offset_idx, i, j]
968
pj = stride_w * j - pad_w + dil_w * dj + offset[b, offset_idx + 1, i, j]
972
mask_value = mask[b, mask_idx, i, j]
974
out[b, c_out, i, j] += (
976
* weight[c_out, c, di, dj]
977
* bilinear_interpolate(x[b, c_in, :, :], pi, pj)
979
out += bias.view(1, n_out_channels, 1, 1)
982
@lru_cache(maxsize=None)
983
def get_fn_args(self, device, contiguous, batch_sz, dtype):
993
stride_h, stride_w = stride
995
dil_h, dil_w = dilation
996
weight_h, weight_w = (3, 2)
999
out_h = (in_h + 2 * pad_h - (dil_h * (weight_h - 1) + 1)) // stride_h + 1
1000
out_w = (in_w + 2 * pad_w - (dil_w * (weight_w - 1) + 1)) // stride_w + 1
1002
x = torch.rand(batch_sz, n_in_channels, in_h, in_w, device=device, dtype=dtype, requires_grad=True)
1004
offset = torch.randn(
1006
n_offset_grps * 2 * weight_h * weight_w,
1015
batch_sz, n_offset_grps * weight_h * weight_w, out_h, out_w, device=device, dtype=dtype, requires_grad=True
1018
weight = torch.randn(
1020
n_in_channels // n_weight_grps,
1028
bias = torch.randn(n_out_channels, device=device, dtype=dtype, requires_grad=True)
1031
x = x.permute(0, 1, 3, 2).contiguous().permute(0, 1, 3, 2)
1032
offset = offset.permute(1, 3, 0, 2).contiguous().permute(2, 0, 3, 1)
1033
mask = mask.permute(1, 3, 0, 2).contiguous().permute(2, 0, 3, 1)
1034
weight = weight.permute(3, 2, 0, 1).contiguous().permute(2, 3, 1, 0)
1036
return x, weight, offset, mask, bias, stride, pad, dilation
1038
def make_obj(self, in_channels=6, out_channels=2, kernel_size=(3, 2), groups=2, wrap=False):
1039
obj = ops.DeformConv2d(
1040
in_channels, out_channels, kernel_size, stride=(2, 1), padding=(1, 0), dilation=(2, 1), groups=groups
1042
return DeformConvModuleWrapper(obj) if wrap else obj
1044
@pytest.mark.parametrize("device", cpu_and_cuda())
1045
def test_is_leaf_node(self, device):
1046
op_obj = self.make_obj(wrap=True).to(device=device)
1047
graph_node_names = get_graph_node_names(op_obj)
1049
assert len(graph_node_names) == 2
1050
assert len(graph_node_names[0]) == len(graph_node_names[1])
1051
assert len(graph_node_names[0]) == 1 + op_obj.n_inputs
1053
@pytest.mark.parametrize("device", cpu_and_cuda())
1054
@pytest.mark.parametrize("contiguous", (True, False))
1055
@pytest.mark.parametrize("batch_sz", (0, 33))
1056
@pytest.mark.opcheck_only_one()
1057
def test_forward(self, device, contiguous, batch_sz, dtype=None):
1058
dtype = dtype or self.dtype
1059
x, _, offset, mask, _, stride, padding, dilation = self.get_fn_args(device, contiguous, batch_sz, dtype)
1062
kernel_size = (3, 2)
1064
tol = 2e-3 if dtype is torch.half else 1e-5
1066
layer = self.make_obj(in_channels, out_channels, kernel_size, groups, wrap=False).to(
1067
device=x.device, dtype=dtype
1069
res = layer(x, offset, mask)
1071
weight = layer.weight.data
1072
bias = layer.bias.data
1073
expected = self.expected_fn(x, weight, offset, mask, bias, stride=stride, padding=padding, dilation=dilation)
1075
torch.testing.assert_close(
1076
res.to(expected), expected, rtol=tol, atol=tol, msg=f"\nres:\n{res}\nexpected:\n{expected}"
1080
res = layer(x, offset)
1081
expected = self.expected_fn(x, weight, offset, None, bias, stride=stride, padding=padding, dilation=dilation)
1083
torch.testing.assert_close(
1084
res.to(expected), expected, rtol=tol, atol=tol, msg=f"\nres:\n{res}\nexpected:\n{expected}"
1087
def test_wrong_sizes(self):
1090
kernel_size = (3, 2)
1092
x, _, offset, mask, _, stride, padding, dilation = self.get_fn_args(
1093
"cpu", contiguous=True, batch_sz=10, dtype=self.dtype
1095
layer = ops.DeformConv2d(
1096
in_channels, out_channels, kernel_size, stride=stride, padding=padding, dilation=dilation, groups=groups
1098
with pytest.raises(RuntimeError, match="the shape of the offset"):
1099
wrong_offset = torch.rand_like(offset[:, :2])
1100
layer(x, wrong_offset)
1102
with pytest.raises(RuntimeError, match=r"mask.shape\[1\] is not valid"):
1103
wrong_mask = torch.rand_like(mask[:, :2])
1104
layer(x, offset, wrong_mask)
1106
@pytest.mark.parametrize("device", cpu_and_cuda())
1107
@pytest.mark.parametrize("contiguous", (True, False))
1108
@pytest.mark.parametrize("batch_sz", (0, 33))
1109
@pytest.mark.opcheck_only_one()
1110
def test_backward(self, device, contiguous, batch_sz):
1111
x, weight, offset, mask, bias, stride, padding, dilation = self.get_fn_args(
1112
device, contiguous, batch_sz, self.dtype
1115
def func(x_, offset_, mask_, weight_, bias_):
1116
return ops.deform_conv2d(
1117
x_, offset_, weight_, bias_, stride=stride, padding=padding, dilation=dilation, mask=mask_
1120
gradcheck(func, (x, offset, mask, weight, bias), nondet_tol=1e-5, fast_mode=True)
1122
def func_no_mask(x_, offset_, weight_, bias_):
1123
return ops.deform_conv2d(
1124
x_, offset_, weight_, bias_, stride=stride, padding=padding, dilation=dilation, mask=None
1127
gradcheck(func_no_mask, (x, offset, weight, bias), nondet_tol=1e-5, fast_mode=True)
1130
def script_func(x_, offset_, mask_, weight_, bias_, stride_, pad_, dilation_):
1132
return ops.deform_conv2d(
1133
x_, offset_, weight_, bias_, stride=stride_, padding=pad_, dilation=dilation_, mask=mask_
1137
lambda z, off, msk, wei, bi: script_func(z, off, msk, wei, bi, stride, padding, dilation),
1138
(x, offset, mask, weight, bias),
1144
def script_func_no_mask(x_, offset_, weight_, bias_, stride_, pad_, dilation_):
1146
return ops.deform_conv2d(
1147
x_, offset_, weight_, bias_, stride=stride_, padding=pad_, dilation=dilation_, mask=None
1151
lambda z, off, wei, bi: script_func_no_mask(z, off, wei, bi, stride, padding, dilation),
1152
(x, offset, weight, bias),
1158
@pytest.mark.parametrize("contiguous", (True, False))
1159
@pytest.mark.opcheck_only_one()
1160
def test_compare_cpu_cuda_grads(self, contiguous):
1165
true_cpu_grads = None
1167
init_weight = torch.randn(9, 9, 3, 3, requires_grad=True)
1168
img = torch.randn(8, 9, 1000, 110)
1169
offset = torch.rand(8, 2 * 3 * 3, 1000, 110)
1170
mask = torch.rand(8, 3 * 3, 1000, 110)
1173
img = img.permute(0, 1, 3, 2).contiguous().permute(0, 1, 3, 2)
1174
offset = offset.permute(1, 3, 0, 2).contiguous().permute(2, 0, 3, 1)
1175
mask = mask.permute(1, 3, 0, 2).contiguous().permute(2, 0, 3, 1)
1176
weight = init_weight.permute(3, 2, 0, 1).contiguous().permute(2, 3, 1, 0)
1178
weight = init_weight
1180
for d in ["cpu", "cuda"]:
1181
out = ops.deform_conv2d(img.to(d), offset.to(d), weight.to(d), padding=1, mask=mask.to(d))
1182
out.mean().backward()
1183
if true_cpu_grads is None:
1184
true_cpu_grads = init_weight.grad
1185
assert true_cpu_grads is not None
1187
assert init_weight.grad is not None
1188
res_grads = init_weight.grad.to("cpu")
1189
torch.testing.assert_close(true_cpu_grads, res_grads)
1192
@pytest.mark.parametrize("batch_sz", (0, 33))
@pytest.mark.parametrize("dtype", (torch.float, torch.half))
@pytest.mark.opcheck_only_one()
def test_autocast(self, batch_sz, dtype):
    """Run ``test_forward`` on CUDA with non-contiguous inputs inside a CUDA autocast region.

    ``torch.cuda.amp.autocast()`` is deprecated and emits a ``FutureWarning`` on
    recent PyTorch releases; ``torch.autocast(device_type="cuda")`` is the
    equivalent supported spelling (same default float16 autocast dtype).
    """
    with torch.autocast(device_type="cuda"):
        self.test_forward(torch.device("cuda"), contiguous=False, batch_sz=batch_sz, dtype=dtype)
1199
def test_forward_scriptability(self):
1201
torch.jit.script(ops.DeformConv2d(in_channels=8, out_channels=8, kernel_size=3))
1204
optests.generate_opcheck_tests(
1205
testcase=TestDeformConv,
1206
namespaces=["torchvision"],
1207
failures_dict_path=os.path.join(os.path.dirname(__file__), "optests_failures_dict.json"),
1208
additional_decorators=[],
1214
def test_frozenbatchnorm2d_repr(self):
1217
t = ops.misc.FrozenBatchNorm2d(num_features, eps=eps)
1220
expected_string = f"FrozenBatchNorm2d({num_features}, eps={eps})"
1221
assert repr(t) == expected_string
1223
@pytest.mark.parametrize("seed", range(10))
1224
def test_frozenbatchnorm2d_eps(self, seed):
1225
torch.random.manual_seed(seed)
1226
sample_size = (4, 32, 28, 28)
1227
x = torch.rand(sample_size)
1229
weight=torch.rand(sample_size[1]),
1230
bias=torch.rand(sample_size[1]),
1231
running_mean=torch.rand(sample_size[1]),
1232
running_var=torch.rand(sample_size[1]),
1233
num_batches_tracked=torch.tensor(100),
1237
fbn = ops.misc.FrozenBatchNorm2d(sample_size[1])
1238
fbn.load_state_dict(state_dict, strict=False)
1239
bn = torch.nn.BatchNorm2d(sample_size[1]).eval()
1240
bn.load_state_dict(state_dict)
1242
torch.testing.assert_close(fbn(x), bn(x), rtol=1e-5, atol=1e-6)
1245
fbn = ops.misc.FrozenBatchNorm2d(sample_size[1], eps=1e-5)
1246
fbn.load_state_dict(state_dict, strict=False)
1247
bn = torch.nn.BatchNorm2d(sample_size[1], eps=1e-5).eval()
1248
bn.load_state_dict(state_dict)
1249
torch.testing.assert_close(fbn(x), bn(x), rtol=1e-5, atol=1e-6)
1252
class TestBoxConversionToRoi:
1253
def _get_box_sequences():
1255
box_tensor = torch.tensor([[0, 0, 0, 100, 100], [1, 0, 0, 100, 100]], dtype=torch.float)
1257
torch.tensor([[0, 0, 100, 100]], dtype=torch.float),
1258
torch.tensor([[0, 0, 100, 100]], dtype=torch.float),
1260
box_tuple = tuple(box_list)
1261
return box_tensor, box_list, box_tuple
1263
@pytest.mark.parametrize("box_sequence", _get_box_sequences())
1264
def test_check_roi_boxes_shape(self, box_sequence):
1266
ops._utils.check_roi_boxes_shape(box_sequence)
1268
@pytest.mark.parametrize("box_sequence", _get_box_sequences())
1269
def test_convert_boxes_to_roi_format(self, box_sequence):
1272
if ref_tensor is None:
1273
ref_tensor = box_sequence
1275
assert_equal(ref_tensor, ops._utils.convert_boxes_to_roi_format(box_sequence))
1278
class TestBoxConvert:
1279
def test_bbox_same(self):
1280
box_tensor = torch.tensor(
1281
[[0, 0, 100, 100], [0, 0, 0, 0], [10, 15, 30, 35], [23, 35, 93, 95]], dtype=torch.float
1284
exp_xyxy = torch.tensor([[0, 0, 100, 100], [0, 0, 0, 0], [10, 15, 30, 35], [23, 35, 93, 95]], dtype=torch.float)
1286
assert exp_xyxy.size() == torch.Size([4, 4])
1287
assert_equal(ops.box_convert(box_tensor, in_fmt="xyxy", out_fmt="xyxy"), exp_xyxy)
1288
assert_equal(ops.box_convert(box_tensor, in_fmt="xywh", out_fmt="xywh"), exp_xyxy)
1289
assert_equal(ops.box_convert(box_tensor, in_fmt="cxcywh", out_fmt="cxcywh"), exp_xyxy)
1291
def test_bbox_xyxy_xywh(self):
1294
box_tensor = torch.tensor(
1295
[[0, 0, 100, 100], [0, 0, 0, 0], [10, 15, 30, 35], [23, 35, 93, 95]], dtype=torch.float
1297
exp_xywh = torch.tensor([[0, 0, 100, 100], [0, 0, 0, 0], [10, 15, 20, 20], [23, 35, 70, 60]], dtype=torch.float)
1299
assert exp_xywh.size() == torch.Size([4, 4])
1300
box_xywh = ops.box_convert(box_tensor, in_fmt="xyxy", out_fmt="xywh")
1301
assert_equal(box_xywh, exp_xywh)
1304
box_xyxy = ops.box_convert(box_xywh, in_fmt="xywh", out_fmt="xyxy")
1305
assert_equal(box_xyxy, box_tensor)
1307
def test_bbox_xyxy_cxcywh(self):
1310
box_tensor = torch.tensor(
1311
[[0, 0, 100, 100], [0, 0, 0, 0], [10, 15, 30, 35], [23, 35, 93, 95]], dtype=torch.float
1313
exp_cxcywh = torch.tensor(
1314
[[50, 50, 100, 100], [0, 0, 0, 0], [20, 25, 20, 20], [58, 65, 70, 60]], dtype=torch.float
1317
assert exp_cxcywh.size() == torch.Size([4, 4])
1318
box_cxcywh = ops.box_convert(box_tensor, in_fmt="xyxy", out_fmt="cxcywh")
1319
assert_equal(box_cxcywh, exp_cxcywh)
1322
box_xyxy = ops.box_convert(box_cxcywh, in_fmt="cxcywh", out_fmt="xyxy")
1323
assert_equal(box_xyxy, box_tensor)
1325
def test_bbox_xywh_cxcywh(self):
1326
box_tensor = torch.tensor(
1327
[[0, 0, 100, 100], [0, 0, 0, 0], [10, 15, 20, 20], [23, 35, 70, 60]], dtype=torch.float
1330
exp_cxcywh = torch.tensor(
1331
[[50, 50, 100, 100], [0, 0, 0, 0], [20, 25, 20, 20], [58, 65, 70, 60]], dtype=torch.float
1334
assert exp_cxcywh.size() == torch.Size([4, 4])
1335
box_cxcywh = ops.box_convert(box_tensor, in_fmt="xywh", out_fmt="cxcywh")
1336
assert_equal(box_cxcywh, exp_cxcywh)
1339
box_xywh = ops.box_convert(box_cxcywh, in_fmt="cxcywh", out_fmt="xywh")
1340
assert_equal(box_xywh, box_tensor)
1342
@pytest.mark.parametrize("inv_infmt", ["xwyh", "cxwyh"])
1343
@pytest.mark.parametrize("inv_outfmt", ["xwcx", "xhwcy"])
1344
def test_bbox_invalid(self, inv_infmt, inv_outfmt):
1345
box_tensor = torch.tensor(
1346
[[0, 0, 100, 100], [0, 0, 0, 0], [10, 15, 20, 20], [23, 35, 70, 60]], dtype=torch.float
1349
with pytest.raises(ValueError):
1350
ops.box_convert(box_tensor, inv_infmt, inv_outfmt)
1352
def test_bbox_convert_jit(self):
1353
box_tensor = torch.tensor(
1354
[[0, 0, 100, 100], [0, 0, 0, 0], [10, 15, 30, 35], [23, 35, 93, 95]], dtype=torch.float
1357
scripted_fn = torch.jit.script(ops.box_convert)
1359
box_xywh = ops.box_convert(box_tensor, in_fmt="xyxy", out_fmt="xywh")
1360
scripted_xywh = scripted_fn(box_tensor, "xyxy", "xywh")
1361
torch.testing.assert_close(scripted_xywh, box_xywh)
1363
box_cxcywh = ops.box_convert(box_tensor, in_fmt="xyxy", out_fmt="cxcywh")
1364
scripted_cxcywh = scripted_fn(box_tensor, "xyxy", "cxcywh")
1365
torch.testing.assert_close(scripted_cxcywh, box_cxcywh)
1369
def area_check(self, box, expected, atol=1e-4):
    """Assert that ``ops.box_area(box)`` equals ``expected``.

    Only the absolute tolerance ``atol`` is used (``rtol=0.0``) and the dtypes
    of the two tensors are not compared.
    """
    computed_area = ops.box_area(box)
    torch.testing.assert_close(computed_area, expected, atol=atol, rtol=0.0, check_dtype=False)
1373
@pytest.mark.parametrize("dtype", [torch.int8, torch.int16, torch.int32, torch.int64])
def test_int_boxes(self, dtype):
    """Areas of integer boxes: [0, 0, 100, 100] -> 10000 and an all-zero box -> 0."""
    boxes = torch.tensor([[0, 0, 100, 100], [0, 0, 0, 0]], dtype=dtype)
    self.area_check(boxes, torch.tensor([10000, 0], dtype=torch.int32))
1379
@pytest.mark.parametrize("dtype", [torch.float32, torch.float64])
def test_float_boxes(self, dtype):
    """Areas of the shared ``FLOAT_BOXES`` fixture in single and double precision."""
    boxes = torch.tensor(FLOAT_BOXES, dtype=dtype)
    reference_areas = torch.tensor([604723.0806, 600965.4666, 592761.0085], dtype=dtype)
    self.area_check(boxes, reference_areas)
1385
def test_float16_box(self):
1386
box_tensor = torch.tensor(
1387
[[2.825, 1.8625, 3.90, 4.85], [2.825, 4.875, 19.20, 5.10], [2.925, 1.80, 8.90, 4.90]], dtype=torch.float16
1390
expected = torch.tensor([3.2170, 3.7108, 18.5071], dtype=torch.float16)
1391
self.area_check(box_tensor, expected, atol=0.01)
1393
def test_box_area_jit(self):
    """The TorchScript-compiled ``ops.box_area`` must agree with the eager version."""
    boxes = torch.tensor([[0, 0, 100, 100], [0, 0, 0, 0]], dtype=torch.float)
    eager_area = ops.box_area(boxes)
    scripted_area = torch.jit.script(ops.box_area)(boxes)
    torch.testing.assert_close(scripted_area, eager_area)
1401
# Integer box fixtures for the pairwise IoU tests: comparing the 4 boxes of
# INT_BOXES with the 3 boxes of INT_BOXES2 yields the 4x3 ``int_expected``
# matrices. Coordinates appear to be (x1, y1, x2, y2) corners -- TODO confirm.
INT_BOXES = [
    [0, 0, 100, 100],
    [0, 0, 50, 50],
    [200, 200, 300, 300],
    [0, 0, 25, 25],
]
INT_BOXES2 = [
    [0, 0, 100, 100],
    [0, 0, 50, 50],
    [200, 200, 300, 300],
]
1404
[285.3538, 185.5758, 1193.5110, 851.4551],
1405
[285.1472, 188.7374, 1192.4984, 851.0669],
1406
[279.2440, 197.9812, 1189.4746, 849.2019],
1410
def gen_box(size, dtype=torch.float):
    """Return ``size`` random boxes as a ``(size, 4)`` tensor laid out ``[x1, y1, x2, y2]``.

    The first corner is drawn with ``torch.rand`` and the second corner is the
    first plus another ``torch.rand`` draw, so ``x2 >= x1`` and ``y2 >= y1``.
    """
    corner_shape = (size, 2)
    first_corner = torch.rand(corner_shape, dtype=dtype)
    second_corner = first_corner + torch.rand(corner_shape, dtype=dtype)
    return torch.cat((first_corner, second_corner), dim=-1)
1418
def _run_test(target_fn: Callable, actual_box1, actual_box2, dtypes, atol, expected):
    """Assert ``target_fn(boxes1, boxes2)`` matches ``expected`` for every dtype in ``dtypes``.

    ``actual_box1`` and ``actual_box2`` are converted with
    ``torch.tensor(..., dtype=dtype)`` separately for each dtype. The output is
    compared with an absolute tolerance only (``rtol=0.0``) and without
    checking its dtype.
    """
    # The reference does not depend on the dtype under test, so build it once.
    expected_box = torch.tensor(expected)
    for dtype in dtypes:
        # Convert from the caller's original data on every iteration. Rebinding the
        # parameters to tensors would build later dtypes from the previous dtype's
        # (possibly rounded, e.g. float16) tensor, and torch.tensor() warns when it
        # is handed a tensor.
        boxes1 = torch.tensor(actual_box1, dtype=dtype)
        boxes2 = torch.tensor(actual_box2, dtype=dtype)
        out = target_fn(boxes1, boxes2)
        torch.testing.assert_close(out, expected_box, rtol=0.0, check_dtype=False, atol=atol)
1427
def _run_jit_test(target_fn: Callable, actual_box: List):
    """Check that ``torch.jit.script(target_fn)`` matches eager ``target_fn``.

    ``actual_box`` is converted to a float tensor that is passed as both arguments.
    """
    boxes = torch.tensor(actual_box, dtype=torch.float)
    eager_out = target_fn(boxes, boxes)
    scripted_out = torch.jit.script(target_fn)(boxes, boxes)
    torch.testing.assert_close(scripted_out, eager_out)
1435
def _cartesian_product(boxes1, boxes2, target_fn: Callable):
1438
result = torch.zeros((N, M))
1441
result[i, j] = target_fn(boxes1[i].unsqueeze(0), boxes2[j].unsqueeze(0))
1445
def _run_cartesian_test(target_fn: Callable):
1448
a = TestIouBase._cartesian_product(boxes1, boxes2, target_fn)
1449
b = target_fn(boxes1, boxes2)
1450
torch.testing.assert_close(a, b)
1453
class TestBoxIou(TestIouBase):
1454
int_expected = [[1.0, 0.25, 0.0], [0.25, 1.0, 0.0], [0.0, 0.0, 1.0], [0.0625, 0.25, 0.0]]
1455
float_expected = [[1.0, 0.9933, 0.9673], [0.9933, 1.0, 0.9737], [0.9673, 0.9737, 1.0]]
1457
@pytest.mark.parametrize(
1458
"actual_box1, actual_box2, dtypes, atol, expected",
1460
pytest.param(INT_BOXES, INT_BOXES2, [torch.int16, torch.int32, torch.int64], 1e-4, int_expected),
1461
pytest.param(FLOAT_BOXES, FLOAT_BOXES, [torch.float16], 0.002, float_expected),
1462
pytest.param(FLOAT_BOXES, FLOAT_BOXES, [torch.float32, torch.float64], 1e-3, float_expected),
1465
def test_iou(self, actual_box1, actual_box2, dtypes, atol, expected):
    """Check ``ops.box_iou`` against the parametrized ``expected`` matrix for each dtype."""
    iou_fn = ops.box_iou
    self._run_test(iou_fn, actual_box1, actual_box2, dtypes, atol, expected)

def test_iou_jit(self):
    """Check that the scripted ``ops.box_iou`` agrees with eager mode on ``INT_BOXES``."""
    iou_fn = ops.box_iou
    self._run_jit_test(iou_fn, INT_BOXES)

def test_iou_cartesian(self):
    """Check batched ``ops.box_iou`` against its pair-by-pair evaluation."""
    iou_fn = ops.box_iou
    self._run_cartesian_test(iou_fn)
1475
class TestGeneralizedBoxIou(TestIouBase):
1476
int_expected = [[1.0, 0.25, -0.7778], [0.25, 1.0, -0.8611], [-0.7778, -0.8611, 1.0], [0.0625, 0.25, -0.8819]]
1477
float_expected = [[1.0, 0.9933, 0.9673], [0.9933, 1.0, 0.9737], [0.9673, 0.9737, 1.0]]
1479
@pytest.mark.parametrize(
1480
"actual_box1, actual_box2, dtypes, atol, expected",
1482
pytest.param(INT_BOXES, INT_BOXES2, [torch.int16, torch.int32, torch.int64], 1e-4, int_expected),
1483
pytest.param(FLOAT_BOXES, FLOAT_BOXES, [torch.float16], 0.002, float_expected),
1484
pytest.param(FLOAT_BOXES, FLOAT_BOXES, [torch.float32, torch.float64], 1e-3, float_expected),
1487
def test_iou(self, actual_box1, actual_box2, dtypes, atol, expected):
    """Check ``ops.generalized_box_iou`` against the parametrized ``expected`` matrix for each dtype."""
    iou_fn = ops.generalized_box_iou
    self._run_test(iou_fn, actual_box1, actual_box2, dtypes, atol, expected)

def test_iou_jit(self):
    """Check that the scripted ``ops.generalized_box_iou`` agrees with eager mode on ``INT_BOXES``."""
    iou_fn = ops.generalized_box_iou
    self._run_jit_test(iou_fn, INT_BOXES)

def test_iou_cartesian(self):
    """Check batched ``ops.generalized_box_iou`` against its pair-by-pair evaluation."""
    iou_fn = ops.generalized_box_iou
    self._run_cartesian_test(iou_fn)
1497
class TestDistanceBoxIoU(TestIouBase):
1499
[1.0000, 0.1875, -0.4444],
1500
[0.1875, 1.0000, -0.5625],
1501
[-0.4444, -0.5625, 1.0000],
1502
[-0.0781, 0.1875, -0.6267],
1504
float_expected = [[1.0, 0.9933, 0.9673], [0.9933, 1.0, 0.9737], [0.9673, 0.9737, 1.0]]
1506
@pytest.mark.parametrize(
1507
"actual_box1, actual_box2, dtypes, atol, expected",
1509
pytest.param(INT_BOXES, INT_BOXES2, [torch.int16, torch.int32, torch.int64], 1e-4, int_expected),
1510
pytest.param(FLOAT_BOXES, FLOAT_BOXES, [torch.float16], 0.002, float_expected),
1511
pytest.param(FLOAT_BOXES, FLOAT_BOXES, [torch.float32, torch.float64], 1e-3, float_expected),
1514
def test_iou(self, actual_box1, actual_box2, dtypes, atol, expected):
    """Check ``ops.distance_box_iou`` against the parametrized ``expected`` matrix for each dtype."""
    iou_fn = ops.distance_box_iou
    self._run_test(iou_fn, actual_box1, actual_box2, dtypes, atol, expected)

def test_iou_jit(self):
    """Check that the scripted ``ops.distance_box_iou`` agrees with eager mode on ``INT_BOXES``."""
    iou_fn = ops.distance_box_iou
    self._run_jit_test(iou_fn, INT_BOXES)

def test_iou_cartesian(self):
    """Check batched ``ops.distance_box_iou`` against its pair-by-pair evaluation."""
    iou_fn = ops.distance_box_iou
    self._run_cartesian_test(iou_fn)
1524
class TestCompleteBoxIou(TestIouBase):
1526
[1.0000, 0.1875, -0.4444],
1527
[0.1875, 1.0000, -0.5625],
1528
[-0.4444, -0.5625, 1.0000],
1529
[-0.0781, 0.1875, -0.6267],
1531
float_expected = [[1.0, 0.9933, 0.9673], [0.9933, 1.0, 0.9737], [0.9673, 0.9737, 1.0]]
1533
@pytest.mark.parametrize(
1534
"actual_box1, actual_box2, dtypes, atol, expected",
1536
pytest.param(INT_BOXES, INT_BOXES2, [torch.int16, torch.int32, torch.int64], 1e-4, int_expected),
1537
pytest.param(FLOAT_BOXES, FLOAT_BOXES, [torch.float16], 0.002, float_expected),
1538
pytest.param(FLOAT_BOXES, FLOAT_BOXES, [torch.float32, torch.float64], 1e-3, float_expected),
1541
def test_iou(self, actual_box1, actual_box2, dtypes, atol, expected):
    """Check ``ops.complete_box_iou`` against the parametrized ``expected`` matrix for each dtype."""
    iou_fn = ops.complete_box_iou
    self._run_test(iou_fn, actual_box1, actual_box2, dtypes, atol, expected)

def test_iou_jit(self):
    """Check that the scripted ``ops.complete_box_iou`` agrees with eager mode on ``INT_BOXES``."""
    iou_fn = ops.complete_box_iou
    self._run_jit_test(iou_fn, INT_BOXES)

def test_iou_cartesian(self):
    """Check batched ``ops.complete_box_iou`` against its pair-by-pair evaluation."""
    iou_fn = ops.complete_box_iou
    self._run_cartesian_test(iou_fn)
1551
def get_boxes(dtype, device):
1552
box1 = torch.tensor([-1, -1, 1, 1], dtype=dtype, device=device)
1553
box2 = torch.tensor([0, 0, 1, 1], dtype=dtype, device=device)
1554
box3 = torch.tensor([0, 1, 1, 2], dtype=dtype, device=device)
1555
box4 = torch.tensor([1, 1, 2, 2], dtype=dtype, device=device)
1557
box1s = torch.stack([box2, box2], dim=0)
1558
box2s = torch.stack([box3, box4], dim=0)
1560
return box1, box2, box3, box4, box1s, box2s
1563
def assert_iou_loss(iou_fn, box1, box2, expected_loss, device, reduction="none"):
    """Assert ``iou_fn(box1, box2, reduction=reduction)`` equals ``expected_loss``.

    ``expected_loss`` (a Python number or nested list) is turned into a tensor on
    ``device`` before comparing with ``torch.testing.assert_close`` defaults.
    """
    reference = torch.tensor(expected_loss, device=device)
    computed = iou_fn(box1, box2, reduction=reduction)
    torch.testing.assert_close(computed, reference)
1569
def assert_empty_loss(iou_fn, dtype, device):
1570
box1 = torch.randn([0, 4], dtype=dtype, device=device).requires_grad_()
1571
box2 = torch.randn([0, 4], dtype=dtype, device=device).requires_grad_()
1572
loss = iou_fn(box1, box2, reduction="mean")
1574
torch.testing.assert_close(loss, torch.tensor(0.0, device=device))
1575
assert box1.grad is not None, "box1.grad should not be None after backward is called"
1576
assert box2.grad is not None, "box2.grad should not be None after backward is called"
1577
loss = iou_fn(box1, box2, reduction="none")
1578
assert loss.numel() == 0, f"{str(iou_fn)} for two empty box should be empty"
1581
class TestGeneralizedBoxIouLoss:
1583
@pytest.mark.parametrize("device", cpu_and_cuda())
1584
@pytest.mark.parametrize("dtype", [torch.float32, torch.half])
1585
def test_giou_loss(self, dtype, device):
1586
box1, box2, box3, box4, box1s, box2s = get_boxes(dtype, device)
1589
assert_iou_loss(ops.generalized_box_iou_loss, box1, box1, 0.0, device=device)
1592
assert_iou_loss(ops.generalized_box_iou_loss, box1, box2, 0.75, device=device)
1596
assert_iou_loss(ops.generalized_box_iou_loss, box2, box3, 1.0, device=device)
1600
assert_iou_loss(ops.generalized_box_iou_loss, box2, box4, 1.5, device=device)
1603
assert_iou_loss(ops.generalized_box_iou_loss, box1s, box2s, 2.5, device=device, reduction="sum")
1604
assert_iou_loss(ops.generalized_box_iou_loss, box1s, box2s, 1.25, device=device, reduction="mean")
1608
with pytest.raises(ValueError, match="Invalid"):
1609
ops.generalized_box_iou_loss(box1s, box2s, reduction="xyz")
1611
@pytest.mark.parametrize("device", cpu_and_cuda())
@pytest.mark.parametrize("dtype", [torch.float32, torch.half])
def test_empty_inputs(self, dtype, device):
    """Empty box batches through ``ops.generalized_box_iou_loss`` (checked by ``assert_empty_loss``).

    Expected: mean reduction is 0, box gradients are populated, and "none" reduction is empty.
    """
    loss_fn = ops.generalized_box_iou_loss
    assert_empty_loss(loss_fn, dtype, device)
1617
class TestCompleteBoxIouLoss:
1618
@pytest.mark.parametrize("dtype", [torch.float32, torch.half])
1619
@pytest.mark.parametrize("device", cpu_and_cuda())
1620
def test_ciou_loss(self, dtype, device):
1621
box1, box2, box3, box4, box1s, box2s = get_boxes(dtype, device)
1623
assert_iou_loss(ops.complete_box_iou_loss, box1, box1, 0.0, device=device)
1624
assert_iou_loss(ops.complete_box_iou_loss, box1, box2, 0.8125, device=device)
1625
assert_iou_loss(ops.complete_box_iou_loss, box1, box3, 1.1923, device=device)
1626
assert_iou_loss(ops.complete_box_iou_loss, box1, box4, 1.2500, device=device)
1627
assert_iou_loss(ops.complete_box_iou_loss, box1s, box2s, 1.2250, device=device, reduction="mean")
1628
assert_iou_loss(ops.complete_box_iou_loss, box1s, box2s, 2.4500, device=device, reduction="sum")
1630
with pytest.raises(ValueError, match="Invalid"):
1631
ops.complete_box_iou_loss(box1s, box2s, reduction="xyz")
1633
@pytest.mark.parametrize("device", cpu_and_cuda())
@pytest.mark.parametrize("dtype", [torch.float32, torch.half])
def test_empty_inputs(self, dtype, device):
    """Empty box batches through ``ops.complete_box_iou_loss`` (checked by ``assert_empty_loss``).

    Expected: mean reduction is 0, box gradients are populated, and "none" reduction is empty.
    """
    loss_fn = ops.complete_box_iou_loss
    assert_empty_loss(loss_fn, dtype, device)
1639
class TestDistanceBoxIouLoss:
1640
@pytest.mark.parametrize("device", cpu_and_cuda())
1641
@pytest.mark.parametrize("dtype", [torch.float32, torch.half])
1642
def test_distance_iou_loss(self, dtype, device):
1643
box1, box2, box3, box4, box1s, box2s = get_boxes(dtype, device)
1645
assert_iou_loss(ops.distance_box_iou_loss, box1, box1, 0.0, device=device)
1646
assert_iou_loss(ops.distance_box_iou_loss, box1, box2, 0.8125, device=device)
1647
assert_iou_loss(ops.distance_box_iou_loss, box1, box3, 1.1923, device=device)
1648
assert_iou_loss(ops.distance_box_iou_loss, box1, box4, 1.2500, device=device)
1649
assert_iou_loss(ops.distance_box_iou_loss, box1s, box2s, 1.2250, device=device, reduction="mean")
1650
assert_iou_loss(ops.distance_box_iou_loss, box1s, box2s, 2.4500, device=device, reduction="sum")
1652
with pytest.raises(ValueError, match="Invalid"):
1653
ops.distance_box_iou_loss(box1s, box2s, reduction="xyz")
1655
@pytest.mark.parametrize("device", cpu_and_cuda())
@pytest.mark.parametrize("dtype", [torch.float32, torch.half])
def test_empty_distance_iou_inputs(self, dtype, device):
    """Empty box batches through ``ops.distance_box_iou_loss`` (checked by ``assert_empty_loss``).

    Expected: mean reduction is 0, box gradients are populated, and "none" reduction is empty.
    """
    loss_fn = ops.distance_box_iou_loss
    assert_empty_loss(loss_fn, dtype, device)
1662
def _generate_diverse_input_target_pair(self, shape=(5, 2), **kwargs):
1664
return torch.log(p / (1 - p))
1666
def generate_tensor_with_range_type(shape, range_type, **kwargs):
1667
if range_type != "random_binary":
1669
"small": (0.0, 0.2),
1671
"zeros": (0.0, 0.0),
1673
"random": (0.0, 1.0),
1675
return torch.testing.make_tensor(shape, low=low, high=high, **kwargs)
1677
return torch.randint(0, 2, shape, **kwargs)
1682
for input_range_type, target_range_type in [
1685
("small", "random_binary"),
1688
("big", "random_binary"),
1689
("random", "zeros"),
1691
("random", "random_binary"),
1693
inputs.append(logit(generate_tensor_with_range_type(shape, input_range_type, **kwargs)))
1694
targets.append(generate_tensor_with_range_type(shape, target_range_type, **kwargs))
1696
return torch.cat(inputs), torch.cat(targets)
1698
@pytest.mark.parametrize("alpha", [-1.0, 0.0, 0.58, 1.0])
1699
@pytest.mark.parametrize("gamma", [0, 2])
1700
@pytest.mark.parametrize("device", cpu_and_cuda())
1701
@pytest.mark.parametrize("dtype", [torch.float32, torch.half])
1702
@pytest.mark.parametrize("seed", [0, 1])
1703
def test_correct_ratio(self, alpha, gamma, device, dtype, seed):
1704
if device == "cpu" and dtype is torch.half:
1705
pytest.skip("Currently torch.half is not fully supported on cpu")
1708
torch.random.manual_seed(seed)
1709
inputs, targets = self._generate_diverse_input_target_pair(dtype=dtype, device=device)
1710
focal_loss = ops.sigmoid_focal_loss(inputs, targets, gamma=gamma, alpha=alpha, reduction=reduction)
1711
ce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction=reduction)
1714
focal_loss <= ce_loss
1715
), "focal loss must be less or equal to cross entropy loss with same input"
1717
loss_ratio = (focal_loss / ce_loss).squeeze()
1718
prob = torch.sigmoid(inputs)
1719
p_t = prob * targets + (1 - prob) * (1 - targets)
1720
correct_ratio = (1.0 - p_t) ** gamma
1722
alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
1723
correct_ratio = correct_ratio * alpha_t
1725
tol = 1e-3 if dtype is torch.half else 1e-5
1726
torch.testing.assert_close(correct_ratio, loss_ratio, atol=tol, rtol=tol)
1728
@pytest.mark.parametrize("reduction", ["mean", "sum"])
1729
@pytest.mark.parametrize("device", cpu_and_cuda())
1730
@pytest.mark.parametrize("dtype", [torch.float32, torch.half])
1731
@pytest.mark.parametrize("seed", [2, 3])
1732
def test_equal_ce_loss(self, reduction, device, dtype, seed):
1733
if device == "cpu" and dtype is torch.half:
1734
pytest.skip("Currently torch.half is not fully supported on cpu")
1738
torch.random.manual_seed(seed)
1739
inputs, targets = self._generate_diverse_input_target_pair(dtype=dtype, device=device)
1740
inputs_fl = inputs.clone().requires_grad_()
1741
targets_fl = targets.clone()
1742
inputs_ce = inputs.clone().requires_grad_()
1743
targets_ce = targets.clone()
1744
focal_loss = ops.sigmoid_focal_loss(inputs_fl, targets_fl, gamma=gamma, alpha=alpha, reduction=reduction)
1745
ce_loss = F.binary_cross_entropy_with_logits(inputs_ce, targets_ce, reduction=reduction)
1747
torch.testing.assert_close(focal_loss, ce_loss)
1749
focal_loss.backward()
1751
torch.testing.assert_close(inputs_fl.grad, inputs_ce.grad)
1753
@pytest.mark.parametrize("alpha", [-1.0, 0.0, 0.58, 1.0])
1754
@pytest.mark.parametrize("gamma", [0, 2])
1755
@pytest.mark.parametrize("reduction", ["none", "mean", "sum"])
1756
@pytest.mark.parametrize("device", cpu_and_cuda())
1757
@pytest.mark.parametrize("dtype", [torch.float32, torch.half])
1758
@pytest.mark.parametrize("seed", [4, 5])
1759
def test_jit(self, alpha, gamma, reduction, device, dtype, seed):
1760
if device == "cpu" and dtype is torch.half:
1761
pytest.skip("Currently torch.half is not fully supported on cpu")
1762
script_fn = torch.jit.script(ops.sigmoid_focal_loss)
1763
torch.random.manual_seed(seed)
1764
inputs, targets = self._generate_diverse_input_target_pair(dtype=dtype, device=device)
1765
focal_loss = ops.sigmoid_focal_loss(inputs, targets, gamma=gamma, alpha=alpha, reduction=reduction)
1766
scripted_focal_loss = script_fn(inputs, targets, gamma=gamma, alpha=alpha, reduction=reduction)
1768
tol = 1e-3 if dtype is torch.half else 1e-5
1769
torch.testing.assert_close(focal_loss, scripted_focal_loss, rtol=tol, atol=tol)
1772
@pytest.mark.parametrize("device", cpu_and_cuda())
@pytest.mark.parametrize("dtype", [torch.float32, torch.half])
def test_reduction_mode(self, device, dtype, reduction="xyz"):
    """An invalid ``reduction`` value ("xyz" by default) must make ``ops.sigmoid_focal_loss`` raise ``ValueError``."""
    half_on_cpu = device == "cpu" and dtype is torch.half
    if half_on_cpu:
        pytest.skip("Currently torch.half is not fully supported on cpu")
    torch.random.manual_seed(0)
    inputs, targets = self._generate_diverse_input_target_pair(device=device, dtype=dtype)
    with pytest.raises(ValueError, match="Invalid"):
        ops.sigmoid_focal_loss(inputs, targets, 0.25, 2, reduction)
1783
class TestMasksToBoxes:
1784
def test_masks_box(self):
1785
def masks_box_check(masks, expected, atol=1e-4):
1786
out = ops.masks_to_boxes(masks)
1787
assert out.dtype == torch.float
1788
torch.testing.assert_close(out, expected, rtol=0.0, check_dtype=True, atol=atol)
1792
assets_directory = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets")
1793
mask_path = os.path.join(assets_directory, "masks.tiff")
1794
image = Image.open(mask_path)
1797
def _create_masks(image, masks):
1798
for index in range(image.n_frames):
1800
frame = np.array(image)
1801
masks[index] = torch.tensor(frame)
1805
expected = torch.tensor(
1810
[139, 68, 175, 104],
1811
[160, 112, 198, 145],
1813
[108, 148, 152, 213],
1818
image = _get_image()
1819
for dtype in [torch.float16, torch.float32, torch.float64]:
1820
masks = torch.zeros((image.n_frames, image.height, image.width), dtype=dtype)
1821
masks = _create_masks(image, masks)
1822
masks_box_check(masks, expected)
1825
class TestStochasticDepth:
1826
@pytest.mark.parametrize("seed", range(10))
1827
@pytest.mark.parametrize("p", [0.2, 0.5, 0.8])
1828
@pytest.mark.parametrize("mode", ["batch", "row"])
1829
def test_stochastic_depth_random(self, seed, mode, p):
1830
torch.manual_seed(seed)
1831
stats = pytest.importorskip("scipy.stats")
1833
x = torch.ones(size=(batch_size, 3, 4, 4))
1834
layer = ops.StochasticDepth(p=p, mode=mode)
1840
for _ in range(trials):
1842
non_zero_count = out.sum(dim=(1, 2, 3)).nonzero().size(0)
1844
if non_zero_count == 0:
1848
counts += batch_size - non_zero_count
1849
num_samples += batch_size
1851
p_value = stats.binomtest(counts, num_samples, p=p).pvalue
1852
assert p_value > 0.01
1854
@pytest.mark.parametrize("seed", range(10))
1855
@pytest.mark.parametrize("p", (0, 1))
1856
@pytest.mark.parametrize("mode", ["batch", "row"])
1857
def test_stochastic_depth(self, seed, mode, p):
1858
torch.manual_seed(seed)
1860
x = torch.ones(size=(batch_size, 3, 4, 4))
1861
layer = ops.StochasticDepth(p=p, mode=mode)
1867
assert out.equal(torch.zeros_like(x))
1869
def make_obj(self, p, mode, wrap=False):
    """Build ``ops.StochasticDepth(p, mode)``, wrapped in ``StochasticDepthWrapper`` when ``wrap`` is true."""
    layer = ops.StochasticDepth(p, mode)
    if wrap:
        return StochasticDepthWrapper(layer)
    return layer
1873
@pytest.mark.parametrize("p", (0, 1))
1874
@pytest.mark.parametrize("mode", ["batch", "row"])
1875
def test_is_leaf_node(self, p, mode):
1876
op_obj = self.make_obj(p, mode, wrap=True)
1877
graph_node_names = get_graph_node_names(op_obj)
1879
assert len(graph_node_names) == 2
1880
assert len(graph_node_names[0]) == len(graph_node_names[1])
1881
assert len(graph_node_names[0]) == 1 + op_obj.n_inputs
1885
@pytest.mark.parametrize("norm_layer", [None, nn.BatchNorm2d, nn.LayerNorm])
1886
def test_split_normalization_params(self, norm_layer):
1887
model = models.mobilenet_v3_large(norm_layer=norm_layer)
1888
params = ops._utils.split_normalization_params(model, None if norm_layer is None else [norm_layer])
1890
assert len(params[0]) == 92
1891
assert len(params[1]) == 82
1895
@pytest.mark.parametrize("seed", range(10))
1896
@pytest.mark.parametrize("dim", [2, 3])
1897
@pytest.mark.parametrize("p", [0, 0.5])
1898
@pytest.mark.parametrize("block_size", [5, 11])
1899
@pytest.mark.parametrize("inplace", [True, False])
1900
def test_drop_block(self, seed, dim, p, block_size, inplace):
1901
torch.manual_seed(seed)
1908
x = torch.ones(size=(batch_size, channels, height, width))
1909
layer = ops.DropBlock2d(p=p, block_size=block_size, inplace=inplace)
1910
feature_size = height * width
1912
x = torch.ones(size=(batch_size, channels, depth, height, width))
1913
layer = ops.DropBlock3d(p=p, block_size=block_size, inplace=inplace)
1914
feature_size = depth * height * width
1920
if block_size == height:
1921
for b, c in product(range(batch_size), range(channels)):
1922
assert out[b, c].count_nonzero() in (0, feature_size)
1924
@pytest.mark.parametrize("seed", range(10))
1925
@pytest.mark.parametrize("dim", [2, 3])
1926
@pytest.mark.parametrize("p", [0.1, 0.2])
1927
@pytest.mark.parametrize("block_size", [3])
1928
@pytest.mark.parametrize("inplace", [False])
1929
def test_drop_block_random(self, seed, dim, p, block_size, inplace):
1930
torch.manual_seed(seed)
1937
x = torch.ones(size=(batch_size, channels, height, width))
1938
layer = ops.DropBlock2d(p=p, block_size=block_size, inplace=inplace)
1940
x = torch.ones(size=(batch_size, channels, depth, height, width))
1941
layer = ops.DropBlock3d(p=p, block_size=block_size, inplace=inplace)
1946
cell_numel = torch.tensor(x.shape).prod()
1947
for _ in range(trials):
1948
with torch.no_grad():
1950
non_zero_count = out.nonzero().size(0)
1951
counts += cell_numel - non_zero_count
1952
num_samples += cell_numel
1954
assert abs(p - counts / num_samples) / p < 0.15
1956
def make_obj(self, dim, p, block_size, inplace, wrap=False):
1958
obj = ops.DropBlock2d(p, block_size, inplace)
1960
obj = ops.DropBlock3d(p, block_size, inplace)
1961
return DropBlockWrapper(obj) if wrap else obj
1963
@pytest.mark.parametrize("dim", (2, 3))
1964
@pytest.mark.parametrize("p", [0, 1])
1965
@pytest.mark.parametrize("block_size", [5, 7])
1966
@pytest.mark.parametrize("inplace", [True, False])
1967
def test_is_leaf_node(self, dim, p, block_size, inplace):
1968
op_obj = self.make_obj(dim, p, block_size, inplace, wrap=True)
1969
graph_node_names = get_graph_node_names(op_obj)
1971
assert len(graph_node_names) == 2
1972
assert len(graph_node_names[0]) == len(graph_node_names[1])
1973
assert len(graph_node_names[0]) == 1 + op_obj.n_inputs
1976
if __name__ == "__main__":
    # Allow running this test module directly, e.g. ``python <this file>``.
    pytest.main(args=[__file__])