intel-extension-for-pytorch

Форк
0
5426 строк · 196.0 Кб
1
"""
2
From PyTorch:
3

4
Copyright (c) 2016-     Facebook, Inc            (Adam Paszke)
5
Copyright (c) 2014-     Facebook, Inc            (Soumith Chintala)
6
Copyright (c) 2011-2014 Idiap Research Institute (Ronan Collobert)
7
Copyright (c) 2012-2014 Deepmind Technologies    (Koray Kavukcuoglu)
8
Copyright (c) 2011-2012 NEC Laboratories America (Koray Kavukcuoglu)
9
Copyright (c) 2011-2013 NYU                      (Clement Farabet)
10
Copyright (c) 2006-2010 NEC Laboratories America (Ronan Collobert, Leon Bottou, Iain Melvin, Jason Weston)
11
Copyright (c) 2006      Idiap Research Institute (Samy Bengio)
12
Copyright (c) 2001-2004 Idiap Research Institute (Ronan Collobert, Samy Bengio, Johnny Mariethoz)
13

14
From Caffe2:
15

16
Copyright (c) 2016-present, Facebook Inc. All rights reserved.
17

18
All contributions by Facebook:
19
Copyright (c) 2016 Facebook Inc.
20

21
All contributions by Google:
22
Copyright (c) 2015 Google Inc.
23
All rights reserved.
24

25
All contributions by Yangqing Jia:
26
Copyright (c) 2015 Yangqing Jia
27
All rights reserved.
28

29
All contributions from Caffe:
30
Copyright(c) 2013, 2014, 2015, the respective contributors
31
All rights reserved.
32

33
All other contributions:
34
Copyright(c) 2015, 2016 the respective contributors
35
All rights reserved.
36

37
Caffe2 uses a copyright model similar to Caffe: each contributor holds
38
copyright over their contributions to Caffe2. The project versioning records
39
all such contribution and copyright details. If a contributor wants to further
40
mark their specific copyright on a particular contribution, they should
41
indicate their copyright solely in the commit message of the change when it is
42
committed.
43

44
All rights reserved.
45
"""
46

47
import math
48
import sys
49
import tempfile
50
import unittest
51

52
from copy import deepcopy
53
from functools import reduce
54
from itertools import product
55
from operator import mul
56
from math import pi
57

58

59
import torch
60
import torch.cuda
61
import torch.nn as nn
62
import torch.nn.functional as F
63
from torch.nn.functional import _Reduction
64
from common_utils import (
65
    TestCase,
66
    to_gpu,
67
    freeze_rng_state,
68
    is_iterable,
69
    TEST_WITH_ROCM,
70
    _assertGradAndGradgradChecks,
71
)
72
from torch.autograd.gradcheck import get_numerical_jacobian, iter_tensors
73
from torch.autograd import Variable
74
import torch.backends.cudnn
75

76
TEST_CUDA = torch.cuda.is_available()
77

78
# tarfile module tries to obtain a file object name in python 3.3
79
if sys.version_info[:2] == (3, 3):
80
    TemporaryFile = tempfile.NamedTemporaryFile
81
else:
82
    TemporaryFile = tempfile.TemporaryFile
83
PRECISION = 1e-5
84

85

86
def get_reduction(m):
87
    result = getattr(m, "reduction", None)
88
    if result is None:
89
        result = _Reduction.legacy_get_string(
90
            getattr(m, "sizeAverage", None), True, emit_warning=False
91
        )
92
    assert result is not None
93
    return result
94

95

96
def get_weight(m):
97
    result = getattr(m, "weight", None)
98
    if result is not None:
99
        return result
100
    return getattr(m, "weights", None)
101

102

103
# NOTE [How to check NN module / functional API parity between Python and C++ frontends]
104
#
105
# The way to check API parity is to add parity tests for the NN module / functional of interest.
106
# Here are the detailed steps:
107
#
108
# For NN module:
109
# 1. Make sure you already have a test dict with the module configuration you want to test.
110
# 2. Add `cpp_constructor_args` entry to the test dict, with its value exactly matching
111
#    the Python module constructor arguments. For example, if in the test dict we pass
112
#    `(10, 8)` to `torch.nn.Linear` constructor, then we should pass `torch::nn::LinearOptions(10, 8)`
113
#    as the corresponding C++ constructor argument to `torch::nn::Linear`.
114
# 3. If in the process of performing the above step you referenced any variables
115
#    in the `cpp_constructor_args` entry, you must add `cpp_var_map` entry
116
#    to the test dict to make sure that those variables are populated with the right Python values.
117
#    For example, if the Python constructor call is
118
#    `torch.nn.FractionalMaxPool2d(2, output_ratio=0.5, _random_samples=random_samples)`,
119
#    the corresponding C++ constructor argument is
120
#    `torch::nn::FractionalMaxPool2dOptions(2).output_ratio(0.5)._random_samples(random_samples)`,
121
#    and the `cpp_var_map` entry must be
122
#    `{'random_samples': random_samples}` in order to populate the C++ variable `random_samples`
123
#    used in the C++ constructor argument with the Python tensor value `random_samples`.
124
#
125
# For NN functional:
126
# 1. Make sure you already have a test dict with the functional configuration you want to test.
127
# 2. If the test dict's `constructor` entry looks like `wrap_functional(F.some_functional_name, ...)`,
128
#    then you must add `cpp_options_args` entry to the test dict, with its value exactly matching the Python
129
#    functional optional arguments. For example, if the test dict's `constructor` entry is
130
#    `wrap_functional(F.interpolate, size=12, scale_factor=None, mode='nearest')`,
131
#    then the `cpp_options_args` entry should be
132
#    "F::InterpolateFuncOptions().size(std::vector<int64_t>({12})).scale_factor(c10::nullopt).mode(torch::kNearest)".
133
# 3. Otherwise, if the test dict's `constructor` entry looks like
134
#    `wrap_functional(lambda i: F.some_functional_name(...))`,
135
#    then you must add `cpp_function_call` entry to the test dict, with its value exactly matching the Python
136
#    functional function call. For example, if the test dict's `constructor` entry is
137
#    `wrap_functional(lambda i: F.poisson_nll_loss(i, t.type_as(i), reduction='none'))`,
138
#    then the `cpp_function_call` entry should be
139
#    "F::poisson_nll_loss(i, t.to(i.options()), F::PoissonNLLLossFuncOptions().reduction(torch::kNone))".
140
# 4. If in the process of performing the above two steps you referenced any variables
141
#    in the `cpp_options_args` or `cpp_function_call` entry, you must
142
#    add `cpp_var_map` entry to the test dict to make sure that those variables
143
#    are populated with the right Python values. For example, if the test dict's `constructor` entry is
144
#    `wrap_functional(lambda i: F.poisson_nll_loss(i, t.type_as(i), reduction='none'))`,
145
#    then the `cpp_function_call` entry should be
146
#    "F::poisson_nll_loss(i, t.to(i.options()), F::PoissonNLLLossFuncOptions().reduction(torch::kNone))".
147
#    Notice that there are two variables `i` and `t` that need to have their values provided,
148
#    and the way to do so is to add a `cpp_var_map` entry: `cpp_var_map={'i': '_get_input()', 't': t}`.
149
#    (Note that for `i`, since we want it to take the Python input value, we pass '_get_input()' string as value
150
#    and the C++ parity test mechanism will populate `i` with the Python input value correctly.)
151
#
152
# There are also a few optional flags in the test dict to control the C++ parity test behavior:
153
#
154
# - `test_cpp_api_parity`: if `False`, skips the C++ parity test for this test dict. Default: True.
155
# - `has_parity`: if `False`, expects this test dict to fail the C++ parity test. Default: True.
156

157
module_tests = [
158
    dict(
159
        module_name="Linear",
160
        constructor_args=(10, 8),
161
        cpp_constructor_args="torch::nn::LinearOptions(10, 8)",
162
        input_size=(4, 10),
163
        reference_fn=lambda i, p, _: torch.mm(i, p[0].t())
164
        + p[1].view(1, -1).expand(4, 8),
165
        check_gradgrad=False,
166
    ),
167
    dict(
168
        module_name="Linear",
169
        constructor_args=(10, 8, False),
170
        cpp_constructor_args="torch::nn::LinearOptions(10, 8).bias(false)",
171
        input_size=(4, 10),
172
        desc="no_bias",
173
        reference_fn=lambda i, p, _: torch.mm(i, p[0].t()),
174
        check_gradgrad=False,
175
    ),
176
    dict(
177
        module_name="Threshold",
178
        constructor_args=(2.0, 1.0),
179
        cpp_constructor_args="torch::nn::ThresholdOptions(2., 1.)",
180
        input_size=(2, 3, 4, 5),
181
        check_inplace=True,
182
        desc="threshold_value",
183
    ),
184
    dict(
185
        module_name="Threshold",
186
        constructor_args=(2.0, 10.0),
187
        cpp_constructor_args="torch::nn::ThresholdOptions(2., 10.)",
188
        input_size=(2, 3, 4, 5),
189
        desc="large_value",
190
    ),
191
    dict(
192
        module_name="ReLU",
193
        input_size=(2, 3, 4, 5),
194
        check_inplace=True,
195
    ),
196
    dict(
197
        module_name="ReLU6",
198
        input_size=(2, 3, 4, 5),
199
        check_inplace=True,
200
    ),
201
    dict(
202
        module_name="RReLU",
203
        input_size=(1, 2, 2),
204
        test_cuda=False,
205
    ),
206
    dict(
207
        module_name="RReLU",
208
        constructor_args=(0.1, 0.9),
209
        cpp_constructor_args="torch::nn::RReLUOptions().lower(0.1).upper(0.9)",
210
        input_size=(4, 4, 5),
211
        desc="with_up_down",
212
        test_cuda=False,
213
    ),
214
    dict(
215
        module_name="Hardtanh",
216
        input_size=(3, 2, 5),
217
        reference_fn=lambda i, *_: i.clamp(-1, 1),
218
    ),
219
    dict(
220
        module_name="Sigmoid",
221
        input_size=(2, 3, 4, 5),
222
    ),
223
    dict(
224
        module_name="Tanh",
225
        input_size=(2, 3, 4, 5),
226
    ),
227
    dict(
228
        module_name="Flatten",
229
        input_size=(2, 3, 4, 5),
230
        reference_fn=lambda i, *_: torch.flatten(i, 1),
231
    ),
232
    dict(
233
        module_name="Softmax",
234
        constructor_args=(1,),
235
        cpp_constructor_args="torch::nn::SoftmaxOptions(1)",
236
        input_size=(10, 20),
237
        reference_fn=lambda i, *_: torch.exp(i).div(
238
            torch.exp(i).sum(1, True).expand(10, 20)
239
        ),
240
    ),
241
    dict(
242
        module_name="Softmax2d",
243
        input_size=(1, 3, 10, 20),
244
        reference_fn=lambda i, *_: torch.exp(i).div(torch.exp(i).sum(1, False)),
245
    ),
246
    dict(
247
        module_name="LogSoftmax",
248
        constructor_args=(1,),
249
        cpp_constructor_args="torch::nn::LogSoftmaxOptions(1)",
250
        input_size=(10, 20),
251
        reference_fn=lambda i, *_: torch.exp(i)
252
        .div_(torch.exp(i).sum(1, True).expand(10, 20))
253
        .log_(),
254
    ),
255
    dict(
256
        module_name="LogSoftmax",
257
        constructor_args=(1,),
258
        cpp_constructor_args="torch::nn::LogSoftmaxOptions(1)",
259
        input_size=(1, 3, 10, 20),
260
        reference_fn=lambda i, *_: torch.exp(i).div_(torch.exp(i).sum(1, False)).log_(),
261
        desc="multiparam",
262
    ),
263
    dict(
264
        module_name="ELU",
265
        constructor_args=(2.0,),
266
        cpp_constructor_args="torch::nn::ELUOptions().alpha(2.)",
267
        input_size=(3, 2, 5),
268
        reference_fn=lambda x, *_: torch.where(x >= 0, x, 2 * (x.exp() - 1)),
269
    ),
270
    # TODO: reference function
271
    dict(
272
        module_name="Hardshrink",
273
        constructor_args=(2.0,),
274
        cpp_constructor_args="torch::nn::HardshrinkOptions(2.)",
275
        input_size=(4, 3, 2, 4),
276
    ),
277
    dict(module_name="LeakyReLU", input_size=(3, 2, 5), check_inplace=True),
278
    dict(
279
        module_name="LeakyReLU",
280
        constructor_args=(0.5,),
281
        cpp_constructor_args="torch::nn::LeakyReLUOptions().negative_slope(0.5)",
282
        input_size=(3, 2, 5),
283
        check_inplace=True,
284
        desc="with_negval",
285
    ),
286
    dict(
287
        module_name="LogSigmoid",
288
        input_size=(2, 3, 4),
289
        reference_fn=lambda i, *_: i.sigmoid().log(),
290
    ),
291
    dict(
292
        module_name="Softplus",
293
        input_size=(10, 20),
294
        reference_fn=lambda i, *_: torch.log(1 + torch.exp(i)),
295
    ),
296
    dict(
297
        module_name="Softplus",
298
        constructor_args=(2,),
299
        cpp_constructor_args="torch::nn::SoftplusOptions().beta(2)",
300
        input_size=(10, 20),
301
        reference_fn=lambda i, *_: 1.0 / 2.0 * torch.log(1 + torch.exp(2 * i)),
302
        desc="beta",
303
    ),
304
    dict(
305
        module_name="Softplus",
306
        constructor_args=(2, -100),
307
        cpp_constructor_args="torch::nn::SoftplusOptions().beta(2).threshold(-100)",
308
        input_size=(10, 20),
309
        reference_fn=(
310
            lambda i, *_: ((i * 2) > -100).type_as(i) * i
311
            + ((i * 2) <= -100).type_as(i) * 1.0 / 2.0 * torch.log(1 + torch.exp(2 * i))
312
        ),
313
        desc="beta_threshold",
314
    ),
315
    dict(
316
        module_name="Softshrink",
317
        input_size=(3, 2, 5),
318
    ),
319
    dict(
320
        module_name="Softshrink",
321
        constructor_args=(1,),
322
        cpp_constructor_args="torch::nn::SoftshrinkOptions(1)",
323
        input_size=(3, 2, 5),
324
        desc="lambda",
325
    ),
326
    dict(
327
        module_name="CrossMapLRN2d",
328
        constructor_args=(5, 5e-3, 1e-3, 2),
329
        cpp_constructor_args="torch::nn::CrossMapLRN2dOptions(5).alpha(5e-3).beta(1e-3).k(2)",
330
        input_size=(2, 3, 6, 6),
331
        check_gradgrad=False,
332
    ),
333
    dict(
334
        module_name="PReLU",
335
        input_size=(2, 3, 4),
336
        reference_fn=lambda i, p, _: torch.clamp(i, min=0)
337
        + torch.clamp(i, max=0) * p[0][0],
338
        desc="1d",
339
    ),
340
    dict(
341
        module_name="PReLU",
342
        constructor_args=(3,),
343
        cpp_constructor_args="torch::nn::PReLUOptions().num_parameters(3)",
344
        input_size=(2, 3, 4),
345
        desc="1d_multiparam",
346
        reference_fn=lambda i, p, _: torch.clamp(i, min=0)
347
        + torch.clamp(i, max=0) * p[0][0],
348
    ),
349
    dict(
350
        module_name="PReLU",
351
        input_size=(2, 3, 4, 5),
352
        desc="2d",
353
        reference_fn=lambda i, p, _: torch.clamp(i, min=0)
354
        + torch.clamp(i, max=0) * p[0][0],
355
    ),
356
    dict(
357
        module_name="PReLU",
358
        constructor_args=(3,),
359
        cpp_constructor_args="torch::nn::PReLUOptions().num_parameters(3)",
360
        input_size=(2, 3, 4, 5),
361
        desc="2d_multiparam",
362
        reference_fn=lambda i, p, _: torch.clamp(i, min=0)
363
        + torch.clamp(i, max=0) * p[0][0],
364
    ),
365
    dict(
366
        module_name="PReLU",
367
        input_size=(2, 3, 4, 5, 6),
368
        reference_fn=lambda i, p, _: torch.clamp(i, min=0)
369
        + torch.clamp(i, max=0) * p[0][0],
370
        desc="3d",
371
    ),
372
    dict(
373
        module_name="PReLU",
374
        constructor_args=(3,),
375
        cpp_constructor_args="torch::nn::PReLUOptions().num_parameters(3)",
376
        input_size=(2, 3, 4, 5, 6),
377
        desc="3d_multiparam",
378
        reference_fn=lambda i, p, _: torch.clamp(i, min=0)
379
        + torch.clamp(i, max=0) * p[0][0],
380
    ),
381
    dict(
382
        module_name="Softsign",
383
        input_size=(3, 2, 5),
384
        reference_fn=lambda i, *_: i.div(1 + torch.abs(i)),
385
    ),
386
    dict(
387
        module_name="Softmin",
388
        constructor_args=(1,),
389
        cpp_constructor_args="torch::nn::SoftminOptions(1)",
390
        input_size=(10, 20),
391
    ),
392
    dict(
393
        module_name="Softmin",
394
        constructor_args=(1,),
395
        cpp_constructor_args="torch::nn::SoftminOptions(1)",
396
        input_size=(2, 3, 5, 10),
397
        desc="multidim",
398
    ),
399
    dict(
400
        module_name="Tanhshrink",
401
        input_size=(2, 3, 4, 5),
402
    ),
403
]
404

405

406
# Generates rand tensor with non-equal values. This ensures that duplicate
407
# values won't be causing test failure for modules like MaxPooling.
408
# size should be small, otherwise randperm fails / long overflows.
409
def _rand_tensor_non_equal(*size):
410
    total = reduce(mul, size, 1)
411
    return torch.randperm(total).view(*size).double()
412

413

414
def wrap_functional(fn, **kwargs):
415
    class FunctionalModule(nn.Module):
416
        def forward(self, *args):
417
            return fn(*args, **kwargs)
418

419
    return FunctionalModule
420

421

422
def poissonnllloss_no_reduce_test():
423
    t = torch.randn(10, 10)
424
    return dict(
425
        fullname="PoissonNLLLoss_no_reduce",
426
        constructor=wrap_functional(
427
            lambda i: F.poisson_nll_loss(i, t.type_as(i), reduction="none")
428
        ),
429
        cpp_function_call="F::poisson_nll_loss("
430
        "i, t.to(i.options()), F::PoissonNLLLossFuncOptions().reduction(torch::kNone))",
431
        input_fn=lambda: torch.rand(10, 10),
432
        cpp_var_map={"i": "_get_input()", "t": t},
433
        reference_fn=lambda i, *_: i.exp() - t.mul(i),
434
        pickle=False,
435
    )
436

437

438
def bceloss_no_reduce_test():
439
    t = Variable(torch.randn(15, 10).gt(0).double())
440
    return dict(
441
        fullname="BCELoss_no_reduce",
442
        constructor=wrap_functional(
443
            lambda i: F.binary_cross_entropy(i, t.type_as(i), reduction="none")
444
        ),
445
        cpp_function_call="F::binary_cross_entropy("
446
        "i, t.to(i.options()), F::BinaryCrossEntropyFuncOptions().reduction(torch::kNone))",
447
        input_fn=lambda: torch.rand(15, 10).clamp_(2.8e-2, 1 - 2.8e-2),
448
        cpp_var_map={"i": "_get_input()", "t": t},
449
        reference_fn=lambda i, *_: -(t * i.log() + (1 - t) * (1 - i).log()),
450
        pickle=False,
451
        precision=7e-4,
452
    )
453

454

455
def bceloss_no_reduce_scalar_test():
456
    t = torch.randn(()).gt(0).double()
457
    return dict(
458
        fullname="BCELoss_no_reduce_scalar",
459
        constructor=wrap_functional(
460
            lambda i: F.binary_cross_entropy(i, t.type_as(i), reduction="none")
461
        ),
462
        cpp_function_call="F::binary_cross_entropy("
463
        "i, t.to(i.options()), F::BinaryCrossEntropyFuncOptions().reduction(torch::kNone))",
464
        input_fn=lambda: torch.rand(()).clamp_(2.8e-2, 1 - 2.8e-2),
465
        cpp_var_map={"i": "_get_input()", "t": t},
466
        reference_fn=lambda i, *_: -(t * i.log() + (1 - t) * (1 - i).log()),
467
        pickle=False,
468
    )
469

470

471
def bceloss_weights_no_reduce_test():
472
    t = Variable(torch.randn(15, 10).gt(0).double())
473
    weights = torch.rand(10)
474
    return dict(
475
        fullname="BCELoss_weights_no_reduce",
476
        constructor=wrap_functional(
477
            lambda i: F.binary_cross_entropy(
478
                i, t.type_as(i), weight=weights.type_as(i), reduction="none"
479
            )
480
        ),
481
        cpp_function_call="F::binary_cross_entropy("
482
        "i, t.to(i.options()), "
483
        "F::BinaryCrossEntropyFuncOptions().weight(weights.to(i.options())).reduction(torch::kNone))",
484
        input_fn=lambda: torch.rand(15, 10).clamp_(2.8e-2, 1 - 2.8e-2),
485
        cpp_var_map={"i": "_get_input()", "t": t, "weights": weights},
486
        reference_fn=lambda i, p, m: -(t * i.log() + (1 - t) * (1 - i).log()) * weights,
487
        pickle=False,
488
        precision=3e-4,
489
    )
490

491

492
def bceloss_weights_no_reduce_scalar_test():
493
    t = torch.randn(()).double()
494
    weights = torch.rand(())
495
    return dict(
496
        fullname="BCELoss_weights_no_reduce_scalar",
497
        constructor=wrap_functional(
498
            lambda i: F.binary_cross_entropy(
499
                i, t.type_as(i), weight=weights.type_as(i), reduction="none"
500
            )
501
        ),
502
        cpp_function_call="""F::binary_cross_entropy(
503
            i, t.to(i.options()),
504
            F::BinaryCrossEntropyFuncOptions().weight(weights.to(i.options())).reduction(torch::kNone))""",
505
        cpp_var_map={"i": "_get_input()", "t": t, "weights": weights},
506
        input_fn=lambda: torch.rand(()).clamp_(2.8e-2, 1 - 2.8e-2),
507
        reference_fn=lambda i, *_: -(t * i.log() + (1 - t) * (1 - i).log()) * weights,
508
        pickle=False,
509
    )
510

511

512
def bce_with_logistic_legacy_enum_test():
513
    t = Variable(torch.randn(15, 10).gt(0).double())
514
    sigmoid = nn.Sigmoid()
515
    return dict(
516
        fullname="BCEWithLogitsLoss_legacy_enum",
517
        constructor=wrap_functional(
518
            lambda i: F.binary_cross_entropy_with_logits(i, t.type_as(i), reduce=False)
519
        ),
520
        cpp_function_call="""F::binary_cross_entropy_with_logits(
521
            i, t.to(i.options()), F::BinaryCrossEntropyWithLogitsFuncOptions().reduction(torch::kNone))""",
522
        input_fn=lambda: torch.rand(15, 10).clamp_(2.8e-2, 1 - 2.8e-2),
523
        cpp_var_map={"i": "_get_input()", "t": t},
524
        reference_fn=lambda i, *_: -(
525
            t * sigmoid(i).log() + (1 - t) * (1 - sigmoid(i)).log()
526
        ),
527
        check_gradgrad=False,
528
        pickle=False,
529
    )
530

531

532
def bce_with_logistic_no_reduce_test():
533
    t = Variable(torch.randn(15, 10).gt(0).double())
534
    sigmoid = nn.Sigmoid()
535
    return dict(
536
        fullname="BCEWithLogitsLoss_no_reduce",
537
        constructor=wrap_functional(
538
            lambda i: F.binary_cross_entropy_with_logits(
539
                i, t.type_as(i), reduction="none"
540
            )
541
        ),
542
        cpp_function_call="""F::binary_cross_entropy_with_logits(
543
            i, t.to(i.options()), F::BinaryCrossEntropyWithLogitsFuncOptions().reduction(torch::kNone))""",
544
        input_fn=lambda: torch.rand(15, 10).clamp_(2.8e-2, 1 - 2.8e-2),
545
        cpp_var_map={"i": "_get_input()", "t": t},
546
        reference_fn=lambda i, *_: -(
547
            t * sigmoid(i).log() + (1 - t) * (1 - sigmoid(i)).log()
548
        ),
549
        check_gradgrad=False,
550
        pickle=False,
551
    )
552

553

554
def bce_with_logistic_no_reduce_scalar_test():
555
    t = torch.randn(()).gt(0).double()
556
    sigmoid = nn.Sigmoid()
557
    return dict(
558
        fullname="BCEWithLogitsLoss_no_reduce_scalar",
559
        constructor=wrap_functional(
560
            lambda i: F.binary_cross_entropy_with_logits(
561
                i, t.type_as(i), reduction="none"
562
            )
563
        ),
564
        cpp_function_call="""F::binary_cross_entropy_with_logits(
565
            i, t.to(i.options()), F::BinaryCrossEntropyWithLogitsFuncOptions().reduction(torch::kNone))""",
566
        input_fn=lambda: torch.rand(()).clamp_(2.8e-2, 1 - 2.8e-2),
567
        cpp_var_map={"i": "_get_input()", "t": t},
568
        reference_fn=lambda i, *_: -(
569
            t * sigmoid(i).log() + (1 - t) * (1 - sigmoid(i)).log()
570
        ),
571
        check_gradgrad=False,
572
        pickle=False,
573
    )
574

575

576
def kldivloss_with_target_no_reduce_test():
577
    i = torch.rand(10, 10).log()
578

579
    return dict(
580
        fullname="KLDivLoss_with_target_no_reduce",
581
        constructor=wrap_functional(
582
            lambda t: F.kl_div(i.type_as(t), t, reduction="none")
583
        ),
584
        cpp_function_call="F::kl_div(i.to(t.options()), t, F::KLDivFuncOptions().reduction(torch::kNone))",
585
        input_fn=lambda: torch.rand(10, 10),
586
        cpp_var_map={"i": i, "t": "_get_input()"},
587
        reference_fn=lambda t, *_: loss_reference_fns["KLDivLoss"](
588
            i.type_as(t), t, reduction="none"
589
        ),
590
        pickle=False,
591
    )
592

593

594
def kldivloss_no_reduce_test():
595
    t = torch.randn(10, 10)
596
    return dict(
597
        fullname="KLDivLoss_no_reduce",
598
        constructor=wrap_functional(
599
            lambda i: F.kl_div(i, t.type_as(i), reduction="none")
600
        ),
601
        cpp_function_call="F::kl_div(i, t.to(i.options()), F::KLDivFuncOptions().reduction(torch::kNone))",
602
        input_fn=lambda: torch.rand(10, 10).log(),
603
        cpp_var_map={"i": "_get_input()", "t": t},
604
        reference_fn=lambda i, *_: loss_reference_fns["KLDivLoss"](
605
            i, t.type_as(i), reduction="none"
606
        ),
607
        pickle=False,
608
    )
609

610

611
def kldivloss_no_reduce_scalar_test():
612
    t = torch.randn(())
613
    return dict(
614
        fullname="KLDivLoss_no_reduce_scalar",
615
        constructor=wrap_functional(
616
            lambda i: F.kl_div(i, t.type_as(i), reduction="none")
617
        ),
618
        cpp_function_call="F::kl_div(i, t.to(i.options()), F::KLDivFuncOptions().reduction(torch::kNone))",
619
        input_fn=lambda: torch.rand(()).log(),
620
        cpp_var_map={"i": "_get_input()", "t": t},
621
        reference_fn=lambda i, *_: loss_reference_fns["KLDivLoss"](
622
            i, t.type_as(i), reduction="none"
623
        ),
624
        pickle=False,
625
    )
626

627

628
def l1loss_no_reduce_test():
629
    t = torch.randn(2, 3, 4)
630
    return dict(
631
        fullname="L1Loss_no_reduce",
632
        constructor=wrap_functional(
633
            lambda i: F.l1_loss(i, t.type_as(i), reduction="none")
634
        ),
635
        cpp_function_call="F::l1_loss(i, t.to(i.options()), F::L1LossFuncOptions().reduction(torch::kNone))",
636
        input_fn=lambda: torch.randn(2, 3, 4),
637
        cpp_var_map={"i": "_get_input()", "t": t},
638
        reference_fn=lambda i, *_: (i - t.type_as(i)).abs(),
639
        pickle=False,
640
    )
641

642

643
def l1loss_no_reduce_scalar_test():
644
    t = torch.randn(())
645
    return dict(
646
        fullname="L1Loss_no_reduce_scalar",
647
        constructor=wrap_functional(
648
            lambda i: F.l1_loss(i, t.type_as(i), reduction="none")
649
        ),
650
        cpp_function_call="F::l1_loss(i, t.to(i.options()), F::L1LossFuncOptions().reduction(torch::kNone))",
651
        input_fn=lambda: torch.randn(()),
652
        cpp_var_map={"i": "_get_input()", "t": t},
653
        reference_fn=lambda i, *_: (i - t.type_as(i)).abs(),
654
        pickle=False,
655
    )
656

657

658
def mseloss_no_reduce_test():
659
    input_size = (2, 3, 4, 5)
660
    target = torch.randn(*input_size)
661
    return dict(
662
        fullname="MSELoss_no_reduce",
663
        constructor=wrap_functional(
664
            lambda i: F.mse_loss(i, target.type_as(i), reduction="none")
665
        ),
666
        cpp_function_call="F::mse_loss(i, target.to(i.options()), F::MSELossFuncOptions().reduction(torch::kNone))",
667
        input_size=input_size,
668
        cpp_var_map={"i": "_get_input()", "target": target},
669
        reference_fn=lambda i, *_: (i - target).pow(2),
670
        pickle=False,
671
    )
672

673

674
def mseloss_no_reduce_scalar_test():
675
    input_size = ()
676
    target = torch.randn(input_size)
677
    return dict(
678
        fullname="MSELoss_no_reduce_scalar",
679
        constructor=wrap_functional(
680
            lambda i: F.mse_loss(i, target.type_as(i), reduction="none")
681
        ),
682
        cpp_function_call="F::mse_loss(i, target.to(i.options()), F::MSELossFuncOptions().reduction(torch::kNone))",
683
        input_size=input_size,
684
        cpp_var_map={"i": "_get_input()", "target": target},
685
        reference_fn=lambda i, *_: (i - target).pow(2),
686
        pickle=False,
687
    )
688

689

690
def nllloss_no_reduce_test():
691
    t = Variable(torch.Tensor(15).uniform_().mul(10).floor().long())
692
    kwargs = {"reduction": "none"}
693
    return dict(
694
        fullname="NLLLoss_no_reduce",
695
        constructor=wrap_functional(
696
            lambda i: F.nll_loss(i, t.type_as(i).long(), **kwargs)
697
        ),
698
        cpp_function_call="""F::nll_loss(
699
            i, t.to(i.options()).to(torch::kLong), F::NLLLossFuncOptions().reduction(torch::kNone))""",
700
        input_fn=lambda: torch.rand(15, 10).log(),
701
        cpp_var_map={"i": "_get_input()", "t": t},
702
        reference_fn=lambda i, *_: loss_reference_fns["NLLLoss"](
703
            i, t.type_as(i).long(), **kwargs
704
        ),
705
        pickle=False,
706
    )
707

708

709
def nllloss_no_reduce_ignore_index_test():
710
    t = Variable(torch.Tensor(15).uniform_().mul(10).floor().long())
711
    kwargs = {"ignore_index": 2, "reduction": "none"}
712
    return dict(
713
        fullname="NLLLoss_no_reduce_ignore_index",
714
        constructor=wrap_functional(
715
            lambda i: F.nll_loss(i, t.type_as(i).long(), **kwargs)
716
        ),
717
        cpp_function_call="""F::nll_loss(
718
            i, t.to(i.options()).to(torch::kLong), F::NLLLossFuncOptions().ignore_index(2).reduction(torch::kNone))""",
719
        input_fn=lambda: torch.rand(15, 10).log(),
720
        cpp_var_map={"i": "_get_input()", "t": t},
721
        reference_fn=lambda i, *_: loss_reference_fns["NLLLoss"](
722
            i, t.type_as(i).long(), **kwargs
723
        ),
724
        pickle=False,
725
    )
726

727

728
def nllloss_no_reduce_weights_test():
729
    t = Variable(torch.Tensor(15).uniform_().mul(10).floor().long())
730
    weight = torch.rand(10)
731

732
    def kwargs(i):
733
        return {"weight": weight.type_as(i), "reduction": "none"}
734

735
    return dict(
736
        fullname="NLLLoss_no_reduce_weights",
737
        constructor=wrap_functional(
738
            lambda i: F.nll_loss(i, t.type_as(i).long(), **kwargs(i))
739
        ),
740
        cpp_function_call="""F::nll_loss(
741
            i, t.to(i.options()).to(torch::kLong),
742
            F::NLLLossFuncOptions().weight(weight.to(i.options())).reduction(torch::kNone))""",
743
        input_fn=lambda: torch.rand(15, 10).add(1e-2).log(),
744
        cpp_var_map={"i": "_get_input()", "t": t, "weight": weight},
745
        reference_fn=lambda i, *_: loss_reference_fns["NLLLoss"](
746
            i, t.type_as(i).long(), **kwargs(i)
747
        ),
748
        pickle=False,
749
    )
750

751

752
def nllloss_no_reduce_weights_ignore_index_test():
753
    t = Variable(torch.Tensor(15).uniform_().mul(10).floor().long())
754
    weight = torch.rand(10)
755

756
    def kwargs(i):
757
        return {"weight": weight.type_as(i), "reduction": "none", "ignore_index": 2}
758

759
    return dict(
760
        fullname="NLLLoss_no_reduce_weights_ignore_index",
761
        constructor=wrap_functional(
762
            lambda i: F.nll_loss(i, t.type_as(i).long(), **kwargs(i.data))
763
        ),
764
        cpp_function_call="""F::nll_loss(
765
            i, t.to(i.options()).to(torch::kLong),
766
            F::NLLLossFuncOptions().weight(weight.to(i.options())).reduction(torch::kNone).ignore_index(2))""",
767
        input_fn=lambda: torch.rand(15, 10).add(1e-2).log(),
768
        cpp_var_map={"i": "_get_input()", "t": t, "weight": weight},
769
        reference_fn=lambda i, *_: loss_reference_fns["NLLLoss"](
770
            i, t.type_as(i).long(), **kwargs(i)
771
        ),
772
        pickle=False,
773
    )
774

775

776
def nllloss_no_reduce_weights_ignore_index_neg_test():
777
    t = Variable(torch.Tensor(15).uniform_().mul(10).floor().long())
778
    weight = torch.rand(10)
779

780
    def kwargs(i):
781
        return {"weight": weight.type_as(i), "reduction": "none", "ignore_index": -1}
782

783
    return dict(
784
        fullname="NLLLoss_no_reduce_weights_ignore_index_neg",
785
        constructor=wrap_functional(
786
            lambda i: F.nll_loss(i, t.type_as(i).long(), **kwargs(i))
787
        ),
788
        cpp_function_call="""F::nll_loss(
789
            i, t.to(i.options()).to(torch::kLong),
790
            F::NLLLossFuncOptions().weight(weight.to(i.options())).reduction(torch::kNone).ignore_index(-1))""",
791
        input=torch.rand(15, 10).add(1e-2).log(),
792
        cpp_var_map={"i": "_get_input()", "t": t, "weight": weight},
793
        reference_fn=lambda i, *_: loss_reference_fns["NLLLoss"](
794
            i, t.type_as(i).long(), **kwargs(i)
795
        ),
796
        pickle=False,
797
    )
798

799

800
def nllloss2d_no_reduce_test():
801
    t = Variable(torch.rand(2, 5, 5).mul(3).floor().long())
802
    kwargs = {"reduction": "none"}
803
    return dict(
804
        fullname="NLLLoss2d_no_reduce",
805
        constructor=wrap_functional(
806
            lambda i: F.nll_loss(i, t.type_as(i).long(), **kwargs)
807
        ),
808
        cpp_function_call="""F::nll_loss(
809
            i, t.to(i.options()).to(torch::kLong), F::NLLLossFuncOptions().reduction(torch::kNone))""",
810
        input_fn=lambda: torch.rand(2, 3, 5, 5).log(),
811
        cpp_var_map={"i": "_get_input()", "t": t},
812
        reference_fn=lambda i, *_: loss_reference_fns["NLLLossNd"](
813
            i, t.type_as(i).long(), **kwargs
814
        ),
815
        pickle=False,
816
    )
817

818

819
def nllloss2d_no_reduce_ignore_index_test():
820
    t = Variable(torch.rand(2, 5, 5).mul(3).floor().long())
821
    kwargs = {"ignore_index": 1, "reduction": "none"}
822
    return dict(
823
        fullname="NLLLoss2d_no_reduce_ignore_index",
824
        constructor=wrap_functional(
825
            lambda i: F.nll_loss(i, t.type_as(i).long(), **kwargs)
826
        ),
827
        cpp_function_call="""F::nll_loss(
828
            i, t.to(i.options()).to(torch::kLong), F::NLLLossFuncOptions().ignore_index(1).reduction(torch::kNone))""",
829
        input_fn=lambda: torch.rand(2, 3, 5, 5).log(),
830
        cpp_var_map={"i": "_get_input()", "t": t},
831
        reference_fn=lambda i, *_: loss_reference_fns["NLLLossNd"](
832
            i, t.type_as(i).long(), **kwargs
833
        ),
834
        pickle=False,
835
    )
836

837

838
def nllloss2d_no_reduce_weights_test():
839
    t = Variable(torch.rand(2, 5, 5).mul(3).floor().long())
840
    weight = torch.rand(3)
841

842
    def kwargs(i):
843
        return {"weight": weight.type_as(i), "reduction": "none"}
844

845
    return dict(
846
        fullname="NLLLoss2d_no_reduce_weights",
847
        constructor=wrap_functional(
848
            lambda i: F.nll_loss(i, t.type_as(i).long(), **kwargs(i))
849
        ),
850
        cpp_function_call="""F::nll_loss(
851
            i, t.to(i.options()).to(torch::kLong),
852
            F::NLLLossFuncOptions().weight(weight.to(i.options())).reduction(torch::kNone))""",
853
        input_fn=lambda: torch.rand(2, 3, 5, 5).log(),
854
        cpp_var_map={"i": "_get_input()", "t": t, "weight": weight},
855
        reference_fn=lambda i, *_: loss_reference_fns["NLLLossNd"](
856
            i, t.type_as(i).long(), **kwargs(i)
857
        ),
858
        pickle=False,
859
    )
860

861

862
def nlllossNd_no_reduce_test():
863
    t = Variable(torch.rand(2, 5, 5, 2, 2).mul(3).floor().long())
864
    kwargs = {"reduction": "none"}
865
    return dict(
866
        fullname="NLLLossNd_no_reduce",
867
        constructor=wrap_functional(
868
            lambda i: F.nll_loss(i, t.type_as(i).long(), **kwargs)
869
        ),
870
        cpp_function_call="""F::nll_loss(
871
            i, t.to(i.options()).to(torch::kLong), F::NLLLossFuncOptions().reduction(torch::kNone))""",
872
        input_fn=lambda: torch.rand(2, 3, 5, 5, 2, 2).log(),
873
        cpp_var_map={"i": "_get_input()", "t": t},
874
        reference_fn=lambda i, *_: loss_reference_fns["NLLLossNd"](
875
            i, t.type_as(i).long(), **kwargs
876
        ),
877
        pickle=False,
878
    )
879

880

881
def nlllossNd_no_reduce_ignore_index_test():
882
    t = Variable(torch.rand(2, 5, 5, 2, 2).mul(3).floor().long())
883
    kwargs = {"ignore_index": 1, "reduction": "none"}
884
    return dict(
885
        fullname="NLLLossNd_no_reduce_ignore_index",
886
        constructor=wrap_functional(
887
            lambda i: F.nll_loss(i, t.type_as(i).long(), **kwargs)
888
        ),
889
        cpp_function_call="""F::nll_loss(
890
            i, t.to(i.options()).to(torch::kLong), F::NLLLossFuncOptions().ignore_index(1).reduction(torch::kNone))""",
891
        input_fn=lambda: torch.rand(2, 3, 5, 5, 2, 2).log(),
892
        cpp_var_map={"i": "_get_input()", "t": t},
893
        reference_fn=lambda i, *_: loss_reference_fns["NLLLossNd"](
894
            i, t.type_as(i).long(), **kwargs
895
        ),
896
        pickle=False,
897
    )
898

899

900
def nlllossNd_no_reduce_weights_test():
901
    t = Variable(torch.rand(2, 5, 5, 2, 2).mul(3).floor().long())
902
    weight = torch.rand(3)
903

904
    def kwargs(i):
905
        return {"weight": weight.type_as(i), "reduction": "none"}
906

907
    return dict(
908
        fullname="NLLLossNd_no_reduce_weights",
909
        constructor=wrap_functional(
910
            lambda i: F.nll_loss(i, t.type_as(i).long(), **kwargs(i))
911
        ),
912
        cpp_function_call="""F::nll_loss(
913
            i, t.to(i.options()).to(torch::kLong),
914
            F::NLLLossFuncOptions().weight(weight.to(i.options())).reduction(torch::kNone))""",
915
        input_fn=lambda: torch.rand(2, 3, 5, 5, 2, 2).log(),
916
        cpp_var_map={"i": "_get_input()", "t": t, "weight": weight},
917
        reference_fn=lambda i, *_: loss_reference_fns["NLLLossNd"](
918
            i, t.type_as(i).long(), **kwargs(i)
919
        ),
920
        pickle=False,
921
    )
922

923

924
def smoothl1loss_no_reduce_test():
925
    t = torch.randn(2, 3, 4)
926
    return dict(
927
        fullname="SmoothL1Loss_no_reduce",
928
        constructor=wrap_functional(
929
            lambda i: F.smooth_l1_loss(i, t.type_as(i), reduction="none")
930
        ),
931
        cpp_function_call="""F::smooth_l1_loss(
932
            i, t.to(i.options()), F::SmoothL1LossFuncOptions().reduction(torch::kNone))""",
933
        input_fn=lambda: torch.randn(2, 3, 4),
934
        cpp_var_map={"i": "_get_input()", "t": t},
935
        reference_fn=lambda i, *_: loss_reference_fns["SmoothL1Loss"](
936
            i, t.type_as(i), reduction="none"
937
        ),
938
        pickle=False,
939
    )
940

941

942
def smoothl1loss_no_reduce_scalar_test():
943
    t = torch.randn(())
944
    return dict(
945
        fullname="SmoothL1Loss_no_reduce_scalar",
946
        constructor=wrap_functional(
947
            lambda i: F.smooth_l1_loss(i, t.type_as(i), reduction="none")
948
        ),
949
        cpp_function_call="""F::smooth_l1_loss(
950
            i, t.to(i.options()), F::SmoothL1LossFuncOptions().reduction(torch::kNone))""",
951
        input_fn=lambda: torch.randn(()),
952
        cpp_var_map={"i": "_get_input()", "t": t},
953
        reference_fn=lambda i, *_: loss_reference_fns["SmoothL1Loss"](
954
            i, t.type_as(i), reduction="none"
955
        ),
956
        pickle=False,
957
    )
958

959

960
def multilabelmarginloss_0d_no_reduce_test():
961
    t = torch.zeros(()).long()
962
    return dict(
963
        fullname="MultiLabelMarginLoss_0d_no_reduce",
964
        constructor=wrap_functional(
965
            lambda i: F.multilabel_margin_loss(i, t.type_as(i).long(), reduction="none")
966
        ),
967
        cpp_function_call="""F::multilabel_margin_loss(
968
            i, t.to(i.options()).to(torch::kLong), F::MultilabelMarginLossFuncOptions().reduction(torch::kNone))""",
969
        input_fn=lambda: torch.randn(()),
970
        cpp_var_map={"i": "_get_input()", "t": t},
971
        reference_fn=lambda i, *_: loss_reference_fns["MultiLabelMarginLoss"](
972
            i, t.data.type_as(i).long(), reduction="none"
973
        ),
974
        check_sum_reduction=True,
975
        check_gradgrad=False,
976
        pickle=False,
977
    )
978

979

980
def multilabelmarginloss_1d_no_reduce_test():
981
    t = Variable(torch.rand(10).mul(10).floor().long())
982
    return dict(
983
        fullname="MultiLabelMarginLoss_1d_no_reduce",
984
        constructor=wrap_functional(
985
            lambda i: F.multilabel_margin_loss(i, t.type_as(i).long(), reduction="none")
986
        ),
987
        cpp_function_call="""F::multilabel_margin_loss(
988
            i, t.to(i.options()).to(torch::kLong), F::MultilabelMarginLossFuncOptions().reduction(torch::kNone))""",
989
        input_fn=lambda: torch.randn(10),
990
        cpp_var_map={"i": "_get_input()", "t": t},
991
        reference_fn=lambda i, *_: loss_reference_fns["MultiLabelMarginLoss"](
992
            i, t.data.type_as(i).long(), reduction="none"
993
        ),
994
        check_sum_reduction=True,
995
        check_gradgrad=False,
996
        pickle=False,
997
    )
998

999

1000
def multilabelmarginloss_index_neg_test():
1001
    t = Variable(
1002
        torch.clamp(torch.rand(5, 10).add(-0.5).mul(20).floor().long(), min=-1)
1003
    )
1004
    return dict(
1005
        fullname="MultiLabelMarginLoss_index_neg",
1006
        constructor=wrap_functional(
1007
            lambda i: F.multilabel_margin_loss(i, t.type_as(i).long(), reduction="none")
1008
        ),
1009
        cpp_function_call="""F::multilabel_margin_loss(
1010
            i, t.to(i.options()).to(torch::kLong), F::MultilabelMarginLossFuncOptions().reduction(torch::kNone))""",
1011
        input_fn=lambda: torch.randn(5, 10),
1012
        cpp_var_map={"i": "_get_input()", "t": t},
1013
        reference_fn=lambda i, *_: loss_reference_fns["MultiLabelMarginLoss"](
1014
            i, t.data.type_as(i).long(), reduction="none"
1015
        ),
1016
        check_sum_reduction=True,
1017
        check_gradgrad=False,
1018
        pickle=False,
1019
    )
1020

1021

1022
def multilabelmarginloss_no_reduce_test():
1023
    t = Variable(torch.rand(5, 10).mul(10).floor().long())
1024
    return dict(
1025
        fullname="MultiLabelMarginLoss_no_reduce",
1026
        constructor=wrap_functional(
1027
            lambda i: F.multilabel_margin_loss(i, t.type_as(i).long(), reduction="none")
1028
        ),
1029
        cpp_function_call="""F::multilabel_margin_loss(
1030
            i, t.to(i.options()).to(torch::kLong), F::MultilabelMarginLossFuncOptions().reduction(torch::kNone))""",
1031
        input_fn=lambda: torch.randn(5, 10),
1032
        cpp_var_map={"i": "_get_input()", "t": t},
1033
        reference_fn=lambda i, *_: loss_reference_fns["MultiLabelMarginLoss"](
1034
            i, t.data.type_as(i).long(), reduction="none"
1035
        ),
1036
        check_sum_reduction=True,
1037
        check_gradgrad=False,
1038
        pickle=False,
1039
    )
1040

1041

1042
def hingeembeddingloss_no_reduce_test():
1043
    t = Variable(torch.randn(10).gt(0).double().mul_(2).sub(1))
1044
    return dict(
1045
        fullname="HingeEmbeddingLoss_no_reduce",
1046
        constructor=wrap_functional(
1047
            lambda i: F.hinge_embedding_loss(i, t.type_as(i), reduction="none")
1048
        ),
1049
        cpp_function_call="""F::hinge_embedding_loss(
1050
            i, t.to(i.options()), F::HingeEmbeddingLossFuncOptions().reduction(torch::kNone))""",
1051
        input_fn=lambda: torch.randn(10),
1052
        cpp_var_map={"i": "_get_input()", "t": t},
1053
        reference_fn=lambda i, *_: loss_reference_fns["HingeEmbeddingLoss"](
1054
            i, t.type_as(i), reduction="none"
1055
        ),
1056
        check_sum_reduction=True,
1057
        pickle=False,
1058
    )
1059

1060

1061
def hingeembeddingloss_margin_no_reduce_test():
1062
    t = Variable(torch.randn(10).gt(0).double().mul_(2).sub(1))
1063
    return dict(
1064
        fullname="HingeEmbeddingLoss_margin_no_reduce",
1065
        constructor=wrap_functional(
1066
            lambda i: F.hinge_embedding_loss(
1067
                i, t.type_as(i), margin=0.5, reduction="none"
1068
            )
1069
        ),
1070
        cpp_function_call="""F::hinge_embedding_loss(
1071
            i, t.to(i.options()), F::HingeEmbeddingLossFuncOptions().margin(0.5).reduction(torch::kNone))""",
1072
        input_fn=lambda: torch.randn(10),
1073
        cpp_var_map={"i": "_get_input()", "t": t},
1074
        reference_fn=lambda i, *_: loss_reference_fns["HingeEmbeddingLoss"](
1075
            i, t.type_as(i), margin=0.5, reduction="none"
1076
        ),
1077
        check_sum_reduction=True,
1078
        pickle=False,
1079
    )
1080

1081

1082
def softmarginloss_no_reduce_test():
1083
    t = torch.randn(5, 5)
1084
    return dict(
1085
        fullname="SoftMarginLoss_no_reduce",
1086
        constructor=wrap_functional(
1087
            lambda i: F.soft_margin_loss(i, t.type_as(i), reduction="none")
1088
        ),
1089
        cpp_function_call="""F::soft_margin_loss(
1090
            i, t.to(i.options()), F::SoftMarginLossFuncOptions().reduction(torch::kNone))""",
1091
        input_fn=lambda: torch.randn(5, 5),
1092
        cpp_var_map={"i": "_get_input()", "t": t},
1093
        reference_fn=lambda i, *_: loss_reference_fns["SoftMarginLoss"](
1094
            i, t.type_as(i), reduction="none"
1095
        ),
1096
        pickle=False,
1097
    )
1098

1099

1100
def multilabelsoftmarginloss_no_reduce_test():
1101
    t = torch.rand(5, 10).mul(2).floor()
1102
    return dict(
1103
        fullname="MultiLabelSoftMarginLoss_no_reduce",
1104
        constructor=wrap_functional(
1105
            lambda i: F.multilabel_soft_margin_loss(i, t.type_as(i), reduction="none")
1106
        ),
1107
        cpp_function_call="""F::multilabel_soft_margin_loss(
1108
            i, t.to(i.options()), F::MultilabelSoftMarginLossFuncOptions().reduction(torch::kNone))""",
1109
        input_fn=lambda: torch.randn(5, 10),
1110
        cpp_var_map={"i": "_get_input()", "t": t},
1111
        reference_fn=lambda i, *_: (
1112
            -(t * i.sigmoid().log() + (1 - t) * (-i).sigmoid().log())
1113
        ).sum(dim=1)
1114
        / i.size(1),
1115
        check_gradgrad=False,
1116
        pickle=False,
1117
    )
1118

1119

1120
def multilabelsoftmarginloss_weights_no_reduce_test():
1121
    t = torch.rand(5, 10).mul(2).floor()
1122
    weights = torch.rand(10)
1123
    return dict(
1124
        fullname="MultiLabelSoftMarginLoss_weights_no_reduce",
1125
        constructor=wrap_functional(
1126
            lambda i: F.multilabel_soft_margin_loss(
1127
                i, t.type_as(i), weight=weights.type_as(i), reduction="none"
1128
            )
1129
        ),
1130
        cpp_function_call="""F::multilabel_soft_margin_loss(
1131
            i, t.to(i.options()),
1132
            F::MultilabelSoftMarginLossFuncOptions().weight(weights.to(i.options())).reduction(torch::kNone))""",
1133
        input_fn=lambda: torch.randn(5, 10),
1134
        cpp_var_map={"i": "_get_input()", "t": t, "weights": weights},
1135
        reference_fn=lambda i, *_: (
1136
            -(t * i.sigmoid().log() + (1 - t) * (-i).sigmoid().log()) * weights
1137
        ).sum(dim=1)
1138
        / i.size(1),
1139
        check_sum_reduction=True,
1140
        check_gradgrad=False,
1141
        pickle=False,
1142
    )
1143

1144

1145
def multimarginloss_no_reduce_test():
1146
    t = torch.rand(5).mul(8).floor().long()
1147
    return dict(
1148
        fullname="MultiMarginLoss_no_reduce",
1149
        constructor=wrap_functional(
1150
            lambda i: F.multi_margin_loss(i, t.type_as(i).long(), reduction="none")
1151
        ),
1152
        cpp_function_call="""F::multi_margin_loss(
1153
            i, t.to(i.options()).to(torch::kLong), F::MultiMarginLossFuncOptions().reduction(torch::kNone))""",
1154
        input_fn=lambda: torch.randn(5, 10),
1155
        cpp_var_map={"i": "_get_input()", "t": t},
1156
        reference_fn=lambda i, *_: loss_reference_fns["MultiMarginLoss"](
1157
            i, t.data.type_as(i).long(), reduction="none"
1158
        ),
1159
        check_sum_reduction=True,
1160
        check_gradgrad=False,
1161
        pickle=False,
1162
    )
1163

1164

1165
def multimarginloss_1d_no_reduce_test():
1166
    t = torch.rand(1).mul(8).floor().long()
1167
    return dict(
1168
        fullname="MultiMarginLoss_1d_no_reduce",
1169
        constructor=wrap_functional(
1170
            lambda i: F.multi_margin_loss(i, t.type_as(i).long(), reduction="none")
1171
        ),
1172
        cpp_function_call="""F::multi_margin_loss(
1173
            i, t.to(i.options()).to(torch::kLong), F::MultiMarginLossFuncOptions().reduction(torch::kNone))""",
1174
        input_fn=lambda: torch.randn(10),
1175
        cpp_var_map={"i": "_get_input()", "t": t},
1176
        reference_fn=lambda i, *_: loss_reference_fns["MultiMarginLoss"](
1177
            i, t.data.type_as(i).long(), reduction="none"
1178
        ),
1179
        check_sum_reduction=True,
1180
        check_gradgrad=False,
1181
        pickle=False,
1182
    )
1183

1184

1185
def multimarginloss_1d_input_0d_target_no_reduce_test():
1186
    t = torch.rand(()).mul(8).floor().long()
1187
    return dict(
1188
        fullname="multimarginloss_1d_input_0d_target_no_reduce",
1189
        constructor=wrap_functional(
1190
            lambda i: F.multi_margin_loss(i, t.type_as(i).long(), reduction="none")
1191
        ),
1192
        cpp_function_call="""F::multi_margin_loss(
1193
            i, t.to(i.options()).to(torch::kLong), F::MultiMarginLossFuncOptions().reduction(torch::kNone))""",
1194
        input_fn=lambda: torch.randn(10),
1195
        cpp_var_map={"i": "_get_input()", "t": t},
1196
        reference_fn=lambda i, *_: loss_reference_fns["MultiMarginLoss"](
1197
            i, t.data.type_as(i).long(), reduction="none"
1198
        ),
1199
        check_sum_reduction=True,
1200
        check_gradgrad=False,
1201
        pickle=False,
1202
    )
1203

1204

1205
def multimarginloss_p_no_reduce_test():
1206
    t = torch.rand(5).mul(8).floor().long()
1207
    return dict(
1208
        fullname="MultiMarginLoss_p_no_reduce",
1209
        constructor=wrap_functional(
1210
            lambda i: F.multi_margin_loss(i, t.type_as(i).long(), p=2, reduction="none")
1211
        ),
1212
        cpp_function_call="""F::multi_margin_loss(
1213
            i, t.to(i.options()).to(torch::kLong), F::MultiMarginLossFuncOptions().p(2).reduction(torch::kNone))""",
1214
        input_fn=lambda: torch.randn(5, 10).clamp_(1e-2, 1 - 1e-2),
1215
        cpp_var_map={"i": "_get_input()", "t": t},
1216
        reference_fn=lambda i, *_: loss_reference_fns["MultiMarginLoss"](
1217
            i, t.data.type_as(i).long(), p=2, reduction="none"
1218
        ),
1219
        check_sum_reduction=True,
1220
        check_gradgrad=False,
1221
        pickle=False,
1222
    )
1223

1224

1225
def multimarginloss_margin_no_reduce_test():
1226
    t = torch.rand(5).mul(8).floor().long()
1227
    return dict(
1228
        fullname="MultiMarginLoss_margin_no_reduce",
1229
        constructor=wrap_functional(
1230
            lambda i: F.multi_margin_loss(
1231
                i, t.type_as(i).long(), margin=0.5, reduction="none"
1232
            )
1233
        ),
1234
        cpp_function_call="""F::multi_margin_loss(
1235
            i, t.to(i.options()).to(torch::kLong),
1236
            F::MultiMarginLossFuncOptions().margin(0.5).reduction(torch::kNone))""",
1237
        input_fn=lambda: torch.randn(5, 10),
1238
        cpp_var_map={"i": "_get_input()", "t": t},
1239
        reference_fn=lambda i, *_: loss_reference_fns["MultiMarginLoss"](
1240
            i, t.data.type_as(i).long(), margin=0.5, reduction="none"
1241
        ),
1242
        check_sum_reduction=True,
1243
        check_gradgrad=False,
1244
        pickle=False,
1245
    )
1246

1247

1248
def multimarginloss_weights_no_reduce_test():
1249
    t = torch.rand(5).mul(8).floor().long()
1250
    weights = torch.rand(10)
1251
    return dict(
1252
        fullname="MultiMarginLoss_weights_no_reduce",
1253
        constructor=wrap_functional(
1254
            lambda i: F.multi_margin_loss(
1255
                i, t.type_as(i).long(), weight=weights.type_as(i), reduction="none"
1256
            )
1257
        ),
1258
        cpp_function_call="""F::multi_margin_loss(
1259
            i, t.to(i.options()).to(torch::kLong),
1260
            F::MultiMarginLossFuncOptions().weight(weights.to(i.options())).reduction(torch::kNone))""",
1261
        input_fn=lambda: torch.randn(5, 10),
1262
        cpp_var_map={"i": "_get_input()", "t": t, "weights": weights},
1263
        reference_fn=lambda i, *_: loss_reference_fns["MultiMarginLoss"](
1264
            i, t.data.type_as(i).long(), weight=weights, reduction="none"
1265
        ),
1266
        check_sum_reduction=True,
1267
        check_gradgrad=False,
1268
        pickle=False,
1269
    )
1270

1271

1272
def fractional_max_pool2d_test(test_case):
1273
    random_samples = torch.DoubleTensor(1, 3, 2).uniform_()
1274
    if test_case == "ratio":
1275
        return dict(
1276
            constructor=lambda: nn.FractionalMaxPool2d(
1277
                2, output_ratio=0.5, _random_samples=random_samples
1278
            ),
1279
            cpp_constructor_args="""torch::nn::FractionalMaxPool2dOptions(2)
1280
                                    .output_ratio(0.5)
1281
                                    ._random_samples(random_samples)""",
1282
            input_size=(1, 3, 5, 7),
1283
            cpp_var_map={"random_samples": random_samples},
1284
            fullname="FractionalMaxPool2d_ratio",
1285
        )
1286
    elif test_case == "size":
1287
        return dict(
1288
            constructor=lambda: nn.FractionalMaxPool2d(
1289
                (2, 3), output_size=(4, 3), _random_samples=random_samples
1290
            ),
1291
            cpp_constructor_args="""torch::nn::FractionalMaxPool2dOptions({2, 3})
1292
                                    .output_size(std::vector<int64_t>({4, 3}))
1293
                                    ._random_samples(random_samples)""",
1294
            input_size=(1, 3, 7, 6),
1295
            cpp_var_map={"random_samples": random_samples},
1296
            fullname="FractionalMaxPool2d_size",
1297
        )
1298

1299

1300
def fractional_max_pool3d_test(test_case):
1301
    random_samples = torch.DoubleTensor(2, 4, 3).uniform_()
1302
    if test_case == "ratio":
1303
        return dict(
1304
            constructor=lambda: nn.FractionalMaxPool3d(
1305
                2, output_ratio=0.5, _random_samples=random_samples
1306
            ),
1307
            cpp_constructor_args="""torch::nn::FractionalMaxPool3dOptions(2)
1308
                                    .output_ratio(0.5)
1309
                                    ._random_samples(random_samples)""",
1310
            input_size=(2, 4, 5, 5, 5),
1311
            cpp_var_map={"random_samples": random_samples},
1312
            fullname="FractionalMaxPool3d_ratio",
1313
        )
1314
    elif test_case == "size":
1315
        return dict(
1316
            constructor=lambda: nn.FractionalMaxPool3d(
1317
                (2, 2, 2), output_size=(4, 4, 4), _random_samples=random_samples
1318
            ),
1319
            cpp_constructor_args="""torch::nn::FractionalMaxPool3dOptions({2, 2, 2})
1320
                                    .output_size(std::vector<int64_t>({4, 4, 4}))
1321
                                    ._random_samples(random_samples)""",
1322
            input_size=(2, 4, 7, 7, 7),
1323
            cpp_var_map={"random_samples": random_samples},
1324
            fullname="FractionalMaxPool3d_size",
1325
        )
1326
    elif test_case == "asymsize":
1327
        return dict(
1328
            constructor=lambda: nn.FractionalMaxPool3d(
1329
                (4, 2, 3), output_size=(10, 3, 2), _random_samples=random_samples
1330
            ),
1331
            cpp_constructor_args="""torch::nn::FractionalMaxPool3dOptions({4, 2, 3})
1332
                                    .output_size(std::vector<int64_t>({10, 3, 2}))
1333
                                    ._random_samples(random_samples)""",
1334
            input_size=(2, 4, 16, 7, 5),
1335
            cpp_var_map={"random_samples": random_samples},
1336
            fullname="FractionalMaxPool3d_asymsize",
1337
        )
1338

1339

1340
new_module_tests = [
    poissonnllloss_no_reduce_test(),
    bceloss_no_reduce_test(),
    bceloss_weights_no_reduce_test(),
    bce_with_logistic_legacy_enum_test(),
    bce_with_logistic_no_reduce_test(),
    bceloss_no_reduce_scalar_test(),
    bceloss_weights_no_reduce_scalar_test(),
    bce_with_logistic_no_reduce_scalar_test(),
    kldivloss_with_target_no_reduce_test(),
    kldivloss_no_reduce_test(),
    kldivloss_no_reduce_scalar_test(),
    l1loss_no_reduce_test(),
    l1loss_no_reduce_scalar_test(),
    mseloss_no_reduce_test(),
    mseloss_no_reduce_scalar_test(),
    nllloss_no_reduce_test(),
    nllloss_no_reduce_ignore_index_test(),
    nllloss_no_reduce_weights_test(),
    nllloss_no_reduce_weights_ignore_index_test(),
    nllloss_no_reduce_weights_ignore_index_neg_test(),
    nllloss2d_no_reduce_test(),
    nllloss2d_no_reduce_weights_test(),
    nllloss2d_no_reduce_ignore_index_test(),
    nlllossNd_no_reduce_test(),
    nlllossNd_no_reduce_weights_test(),
    nlllossNd_no_reduce_ignore_index_test(),
    smoothl1loss_no_reduce_test(),
    smoothl1loss_no_reduce_scalar_test(),
    multilabelmarginloss_0d_no_reduce_test(),
    multilabelmarginloss_1d_no_reduce_test(),
    multilabelmarginloss_index_neg_test(),
    multilabelmarginloss_no_reduce_test(),
    hingeembeddingloss_no_reduce_test(),
    hingeembeddingloss_margin_no_reduce_test(),
    softmarginloss_no_reduce_test(),
    multilabelsoftmarginloss_no_reduce_test(),
    multilabelsoftmarginloss_weights_no_reduce_test(),
    multimarginloss_no_reduce_test(),
    multimarginloss_1d_no_reduce_test(),
    multimarginloss_1d_input_0d_target_no_reduce_test(),
    multimarginloss_p_no_reduce_test(),
    multimarginloss_margin_no_reduce_test(),
    multimarginloss_weights_no_reduce_test(),
    fractional_max_pool2d_test("ratio"),
    fractional_max_pool2d_test("size"),
    fractional_max_pool3d_test("ratio"),
    fractional_max_pool3d_test("size"),
    fractional_max_pool3d_test("asymsize"),
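    # BatchNorm1d/2d/3d configurations. Positional constructor_args follow the
    # Python signature (num_features, eps, momentum, affine,
    # track_running_stats); cpp_constructor_args give the matching C++
    # frontend options string used for the parity check.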
    dict(
1390
        module_name="BatchNorm1d",
1391
        constructor_args=(10,),
1392
        cpp_constructor_args="torch::nn::BatchNorm1dOptions(10)",
1393
        input_size=(4, 10),
1394
        cudnn=True,
1395
        check_eval=True,
1396
        desc="affine",
1397
        test_cuda=(not TEST_WITH_ROCM),
1398
        pickle=False,
1399
    ),
1400
    dict(
1401
        module_name="BatchNorm1d",
1402
        constructor_args=(5,),
1403
        cpp_constructor_args="torch::nn::BatchNorm1dOptions(5)",
1404
        input_size=(4, 5, 3),
1405
        cudnn=True,
1406
        check_eval=True,
1407
        desc="3d_input",
1408
        pickle=False,
1409
    ),
1410
    dict(
1411
        module_name="BatchNorm1d",
1412
        constructor_args=(10, 1e-3, None),
1413
        cpp_constructor_args="torch::nn::BatchNorm1dOptions(10).eps(1e-3).momentum(c10::nullopt)",
1414
        input_size=(4, 10),
1415
        cudnn=True,
1416
        check_eval=True,
1417
        desc="affine_simple_average",
1418
        test_cuda=(not TEST_WITH_ROCM),
1419
        pickle=False,
1420
    ),
1421
    dict(
1422
        module_name="BatchNorm1d",
1423
        constructor_args=(10, 1e-3, 0.3, False),
1424
        cpp_constructor_args="torch::nn::BatchNorm1dOptions(10).eps(1e-3).momentum(0.3).affine(false)",
1425
        input_size=(4, 10),
1426
        cudnn=True,
1427
        check_eval=True,
1428
        desc="not_affine",
1429
        pickle=False,
1430
    ),
1431
    dict(
1432
        module_name="BatchNorm1d",
1433
        constructor_args=(10, 1e-3, 0.3, True, False),
1434
        cpp_constructor_args="""torch::nn::BatchNorm1dOptions(10)
1435
                                .eps(1e-3).momentum(0.3).affine(true).track_running_stats(false)""",
1436
        input_size=(4, 10),
1437
        cudnn=True,
1438
        check_eval=True,
1439
        desc="not_tracking_stats",
1440
        test_cuda=(not TEST_WITH_ROCM),
1441
        pickle=False,
1442
    ),
1443
    dict(
1444
        module_name="BatchNorm1d",
1445
        constructor_args=(5, 1e-3, 0.3, False),
1446
        cpp_constructor_args="torch::nn::BatchNorm1dOptions(5).eps(1e-3).momentum(0.3).affine(false)",
1447
        input_size=(4, 5, 3),
1448
        cudnn=True,
1449
        check_eval=True,
1450
        desc="3d_input_not_affine",
1451
        pickle=False,
1452
    ),
1453
    dict(
1454
        module_name="BatchNorm1d",
1455
        constructor_args=(5, 1e-3, 0.3, False),
1456
        cpp_constructor_args="torch::nn::BatchNorm1dOptions(5).eps(1e-3).momentum(0.3).affine(false)",
1457
        input_size=(0, 5, 9),
1458
        cudnn=True,
1459
        check_eval=True,
1460
        desc="zero_batch",
1461
        pickle=False,
1462
    ),
1463
    dict(
1464
        module_name="BatchNorm2d",
1465
        constructor_args=(3,),
1466
        cpp_constructor_args="torch::nn::BatchNorm2dOptions(3)",
1467
        input_size=(2, 3, 6, 6),
1468
        cudnn=True,
1469
        check_eval=True,
1470
        pickle=False,
1471
    ),
1472
    dict(
1473
        module_name="BatchNorm2d",
1474
        constructor_args=(3, 1e-3, None),
1475
        cpp_constructor_args="torch::nn::BatchNorm2dOptions(3).eps(1e-3).momentum(c10::nullopt)",
1476
        input_size=(2, 3, 6, 6),
1477
        cudnn=True,
1478
        check_eval=True,
1479
        desc="2d_simple_average",
1480
        pickle=False,
1481
    ),
1482
    dict(
1483
        module_name="BatchNorm2d",
1484
        constructor_args=(3, 1e-3, 0.8),
1485
        cpp_constructor_args="torch::nn::BatchNorm2dOptions(3).eps(1e-3).momentum(0.8)",
1486
        input_size=(2, 3, 6, 6),
1487
        cudnn=True,
1488
        check_eval=True,
1489
        desc="momentum",
1490
        pickle=False,
1491
    ),
1492
    dict(
1493
        module_name="BatchNorm2d",
1494
        constructor_args=(3, 1e-3, 0.8, False),
1495
        cpp_constructor_args="torch::nn::BatchNorm2dOptions(3).eps(1e-3).momentum(0.8).affine(false)",
1496
        input_size=(2, 3, 6, 6),
1497
        cudnn=True,
1498
        check_eval=True,
1499
        desc="not_affine",
1500
        pickle=False,
1501
    ),
1502
    dict(
1503
        module_name="BatchNorm2d",
1504
        constructor_args=(3, 1e-3, 0.8, True, False),
1505
        cpp_constructor_args="""torch::nn::BatchNorm2dOptions(3)
1506
                                .eps(1e-3).momentum(0.8).affine(true).track_running_stats(false)""",
1507
        input_size=(2, 3, 6, 6),
1508
        cudnn=True,
1509
        check_eval=True,
1510
        desc="not_tracking_stats",
1511
        pickle=False,
1512
    ),
1513
    dict(
1514
        module_name="BatchNorm2d",
1515
        constructor_args=(5, 1e-3, 0.3, False),
1516
        cpp_constructor_args="torch::nn::BatchNorm2dOptions(5).eps(1e-3).momentum(0.3).affine(false)",
1517
        input_size=(0, 5, 2, 2),
1518
        cudnn=True,
1519
        check_eval=True,
1520
        desc="zero_batch",
1521
        pickle=False,
1522
    ),
1523
    dict(
1524
        module_name="BatchNorm3d",
1525
        constructor_args=(3,),
1526
        cpp_constructor_args="torch::nn::BatchNorm3dOptions(3)",
1527
        input_size=(2, 3, 4, 4, 4),
1528
        cudnn=True,
1529
        check_eval=True,
1530
        pickle=False,
1531
    ),
1532
    dict(
1533
        module_name="BatchNorm3d",
1534
        constructor_args=(3, 1e-3, None),
1535
        cpp_constructor_args="torch::nn::BatchNorm3dOptions(3).eps(1e-3).momentum(c10::nullopt)",
1536
        input_size=(2, 3, 4, 4, 4),
1537
        cudnn=True,
1538
        check_eval=True,
1539
        desc="3d_simple_average",
1540
        pickle=False,
1541
    ),
1542
    dict(
1543
        module_name="BatchNorm3d",
1544
        constructor_args=(3, 1e-3, 0.7),
1545
        cpp_constructor_args="torch::nn::BatchNorm3dOptions(3).eps(1e-3).momentum(0.7)",
1546
        input_size=(2, 3, 4, 4, 4),
1547
        cudnn=True,
1548
        check_eval=True,
1549
        desc="momentum",
1550
        pickle=False,
1551
    ),
1552
    dict(
1553
        module_name="BatchNorm3d",
1554
        constructor_args=(3, 1e-3, 0.7, False),
1555
        cpp_constructor_args="torch::nn::BatchNorm3dOptions(3).eps(1e-3).momentum(0.7).affine(false)",
1556
        input_size=(2, 3, 4, 4, 4),
1557
        cudnn=True,
1558
        check_eval=True,
1559
        desc="not_affine",
1560
        pickle=False,
1561
    ),
1562
    dict(
1563
        module_name="BatchNorm3d",
1564
        constructor_args=(3, 1e-3, 0.7, True, False),
1565
        cpp_constructor_args="""torch::nn::BatchNorm3dOptions(3)
1566
                                .eps(1e-3).momentum(0.7).affine(true).track_running_stats(false)""",
1567
        input_size=(2, 3, 4, 4, 4),
1568
        cudnn=True,
1569
        check_eval=True,
1570
        desc="not_tracking_stats",
1571
        pickle=False,
1572
    ),
1573
    dict(
1574
        module_name="BatchNorm3d",
1575
        constructor_args=(5, 1e-3, 0.3, False),
1576
        cpp_constructor_args="torch::nn::BatchNorm3dOptions(5).eps(1e-3).momentum(0.3).affine(false)",
1577
        input_size=(0, 5, 2, 2, 2),
1578
        cudnn=True,
1579
        check_eval=True,
1580
        desc="zero_batch",
1581
        pickle=False,
1582
    ),
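    # InstanceNorm1d/2d/3d configurations: same positional arguments as the
    # BatchNorm variants, but affine and track_running_stats default to False,
    # so the tracking_stats cases pass them explicitly.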
    dict(
1584
        module_name="InstanceNorm1d",
1585
        constructor_args=(3, 1e-3, 0.3),
1586
        cpp_constructor_args="torch::nn::InstanceNorm1dOptions(3).eps(1e-3).momentum(0.3)",
1587
        input_size=(4, 3, 15),
1588
        cudnn=True,
1589
        check_eval=True,
1590
        pickle=False,
1591
    ),
1592
    dict(
1593
        module_name="InstanceNorm1d",
1594
        constructor_args=(3, 1e-3, 0.3, False, True),
1595
        cpp_constructor_args="""torch::nn::InstanceNorm1dOptions(3)
1596
                                .eps(1e-3).momentum(0.3).affine(false).track_running_stats(true)""",
1597
        input_size=(4, 3, 15),
1598
        cudnn=True,
1599
        check_eval=True,
1600
        desc="tracking_stats",
1601
        pickle=False,
1602
    ),
1603
    dict(
1604
        module_name="InstanceNorm2d",
1605
        constructor_args=(3, 1e-3, 0.3),
1606
        cpp_constructor_args="torch::nn::InstanceNorm2dOptions(3).eps(1e-3).momentum(0.3)",
1607
        input_size=(2, 3, 6, 6),
1608
        cudnn=True,
1609
        check_eval=True,
1610
        pickle=False,
1611
    ),
1612
    dict(
1613
        module_name="InstanceNorm2d",
1614
        constructor_args=(3, 1e-3, 0.3, False, True),
1615
        cpp_constructor_args="""torch::nn::InstanceNorm2dOptions(3)
1616
                                .eps(1e-3).momentum(0.3).affine(false).track_running_stats(true)""",
1617
        input_size=(2, 3, 6, 6),
1618
        cudnn=True,
1619
        check_eval=True,
1620
        desc="tracking_stats",
1621
        pickle=False,
1622
    ),
1623
    dict(
1624
        module_name="InstanceNorm3d",
1625
        constructor_args=(3, 1e-3, 0.3),
1626
        cpp_constructor_args="torch::nn::InstanceNorm3dOptions(3).eps(1e-3).momentum(0.3)",
1627
        input_size=(2, 3, 4, 4, 4),
1628
        cudnn=True,
1629
        check_eval=True,
1630
        pickle=False,
1631
    ),
1632
    dict(
1633
        module_name="InstanceNorm3d",
1634
        constructor_args=(3, 1e-3, 0.3, False, True),
1635
        cpp_constructor_args="""torch::nn::InstanceNorm3dOptions(3)
1636
                                .eps(1e-3).momentum(0.3).affine(false).track_running_stats(true)""",
1637
        input_size=(2, 3, 4, 4, 4),
1638
        cudnn=True,
1639
        check_eval=True,
1640
        desc="tracking_stats",
1641
        pickle=False,
1642
    ),
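    # LayerNorm configurations: the first constructor argument is
    # normalized_shape, i.e. the trailing input dimensions to normalize over
    # (e.g. [2, 2, 5] for an input of size (4, 2, 2, 5)).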
    dict(
1644
        module_name="LayerNorm",
1645
        constructor_args=([5], 1e-3),
1646
        cpp_constructor_args="torch::nn::LayerNormOptions({5}).eps(1e-3)",
1647
        input_size=(4, 5, 5),
1648
        cudnn=True,
1649
        check_eval=True,
1650
        desc="1d_elementwise_affine",
1651
    ),
1652
    dict(
1653
        module_name="LayerNorm",
1654
        constructor_args=([5], 1e-3, False),
1655
        cpp_constructor_args="torch::nn::LayerNormOptions({5}).eps(1e-3).elementwise_affine(false)",
1656
        input_size=(4, 5, 5),
1657
        cudnn=True,
1658
        check_eval=True,
1659
        desc="1d_no_elementwise_affine",
1660
    ),
1661
    dict(
1662
        module_name="LayerNorm",
1663
        constructor_args=([2, 2, 5], 1e-3),
1664
        cpp_constructor_args="torch::nn::LayerNormOptions({2, 2, 5}).eps(1e-3)",
1665
        input_size=(4, 2, 2, 5),
1666
        cudnn=True,
1667
        check_eval=True,
1668
        desc="3d_elementwise_affine",
1669
    ),
1670
    dict(
1671
        module_name="LayerNorm",
1672
        constructor_args=([2, 2, 5], 1e-3, False),
1673
        cpp_constructor_args="torch::nn::LayerNormOptions({2, 2, 5}).eps(1e-3).elementwise_affine(false)",
1674
        input_size=(4, 2, 2, 5),
1675
        cudnn=True,
1676
        check_eval=True,
1677
        desc="3d_no_elementwise_affine",
1678
    ),
1679
    dict(
1680
        module_name="LayerNorm",
1681
        constructor_args=([5], 1e-3),
1682
        cpp_constructor_args="torch::nn::LayerNormOptions({5}).eps(1e-3)",
1683
        input_size=(0, 5),
1684
        cudnn=True,
1685
        check_eval=True,
1686
        desc="1d_empty_elementwise_affine",
1687
    ),
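    # GroupNorm configurations: (num_groups, num_channels, eps, affine).
    # num_groups == num_channels matches InstanceNorm and num_groups == 1
    # matches LayerNorm over the non-batch dimensions, as the inline comments
    # below note.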
    dict(
1689
        module_name="GroupNorm",
1690
        constructor_args=(3, 6, 1e-3),
1691
        cpp_constructor_args="torch::nn::GroupNormOptions(3, 6).eps(1e-3)",
1692
        input_size=(4, 6, 5),
1693
        cudnn=True,
1694
        check_eval=True,
1695
        desc="1d_affine",
1696
    ),
1697
    dict(
1698
        module_name="GroupNorm",
1699
        constructor_args=(5, 5, 1e-3, False),
1700
        cpp_constructor_args="torch::nn::GroupNormOptions(5, 5).eps(1e-3).affine(false)",
1701
        input_size=(4, 5, 5),
1702
        cudnn=True,
1703
        check_eval=True,
1704
        desc="1d_no_affine_IN",  # this setting is equivalent with InstanceNormi
1705
    ),
1706
    dict(
1707
        module_name="GroupNorm",
1708
        constructor_args=(1, 5, 1e-3, False),
1709
        cpp_constructor_args="torch::nn::GroupNormOptions(1, 5).eps(1e-3).affine(false)",
1710
        input_size=(4, 5, 5),
1711
        cudnn=True,
1712
        check_eval=True,
1713
        desc="1d_no_affine_LN",  # this setting is equivalent with LayerNorm
1714
    ),
1715
    dict(
1716
        module_name="GroupNorm",
1717
        constructor_args=(3, 6, 1e-3),
1718
        cpp_constructor_args="torch::nn::GroupNormOptions(3, 6).eps(1e-3)",
1719
        input_size=(4, 6, 2, 3),
1720
        cudnn=True,
1721
        check_eval=True,
1722
        desc="2d_affine",
1723
    ),
1724
    dict(
1725
        module_name="GroupNorm",
1726
        constructor_args=(3, 3, 1e-3, False),
1727
        cpp_constructor_args="torch::nn::GroupNormOptions(3, 3).eps(1e-3).affine(false)",
1728
        input_size=(4, 3, 2, 3),
1729
        cudnn=True,
1730
        check_eval=True,
1731
        desc="2d_no_affine_IN",  # this setting is equivalent with InstanceNorm
1732
    ),
1733
    dict(
1734
        module_name="GroupNorm",
1735
        constructor_args=(1, 3, 1e-3, False),
1736
        cpp_constructor_args="torch::nn::GroupNormOptions(1, 3).eps(1e-3).affine(false)",
1737
        input_size=(4, 3, 2, 3),
1738
        cudnn=True,
1739
        check_eval=True,
1740
        desc="2d_no_affine_LN",  # this setting is equivalent with LayerNorm
1741
    ),
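    # Conv1d / ConvTranspose1d configurations. Positional constructor_args are
    # (in_channels, out_channels, kernel_size, stride, padding, ...); e.g.
    # nn.Conv1d(4, 5, 3, 1, 1) keeps a length-10 input at length 10.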
    dict(
1743
        module_name="Conv1d",
1744
        constructor_args=(4, 5, 3),
1745
        cpp_constructor_args="torch::nn::Conv1dOptions(4, 5, 3)",
1746
        input_size=(2, 4, 10),
1747
        cudnn=True,
1748
    ),
1749
    dict(
1750
        module_name="Conv1d",
1751
        constructor_args=(4, 5, 3, 2),
1752
        cpp_constructor_args="torch::nn::Conv1dOptions(4, 5, 3).stride(2)",
1753
        input_size=(2, 4, 10),
1754
        cudnn=True,
1755
        desc="stride",
1756
    ),
1757
    dict(
1758
        module_name="Conv1d",
1759
        constructor_args=(4, 5, 3, 1, 1),
1760
        cpp_constructor_args="torch::nn::Conv1dOptions(4, 5, 3).stride(1).padding(1)",
1761
        input_size=(2, 4, 10),
1762
        cudnn=True,
1763
        desc="pad1",
1764
    ),
1765
    dict(
1766
        module_name="Conv1d",
1767
        constructor_args=(4, 5, 5, 1, 2),
1768
        cpp_constructor_args="torch::nn::Conv1dOptions(4, 5, 5).stride(1).padding(2)",
1769
        input_size=(2, 4, 10),
1770
        cudnn=True,
1771
        desc="pad2",
1772
    ),
1773
    dict(
1774
        module_name="Conv1d",
1775
        constructor_args=(4, 4, 3, 1, 1),
1776
        cpp_constructor_args="torch::nn::Conv1dOptions(4, 4, 3).stride(1).padding(1)",
1777
        input_size=(1, 4, 1),
1778
        cudnn=True,
1779
        desc="pad1size1",
1780
    ),
1781
    dict(
1782
        module_name="Conv1d",
1783
        constructor_args=(4, 4, 5, 1, 2),
1784
        cpp_constructor_args="torch::nn::Conv1dOptions(4, 4, 5).stride(1).padding(2)",
1785
        input_size=(1, 4, 1),
1786
        cudnn=True,
1787
        desc="pad2size1",
1788
    ),
1789
    dict(
1790
        module_name="Conv1d",
1791
        constructor_args=(4, 5, 3),
1792
        cpp_constructor_args="torch::nn::Conv1dOptions(4, 5, 3)",
1793
        input_size=(0, 4, 10),
1794
        cudnn=True,
1795
        desc="zero_batch",
1796
        test_cuda=(not TEST_WITH_ROCM),
1797
    ),
1798
    dict(
1799
        fullname="Conv1d_dilated",
1800
        constructor=lambda: nn.Conv1d(4, 5, kernel_size=3, dilation=2),
1801
        cpp_constructor_args="torch::nn::Conv1dOptions(4, 5, 3).dilation(2)",
1802
        input_size=(2, 4, 10),
1803
    ),
1804
    dict(
1805
        fullname="Conv1d_groups",
1806
        constructor=lambda: nn.Conv1d(4, 6, kernel_size=3, groups=2),
1807
        cpp_constructor_args="torch::nn::Conv1dOptions(4, 6, 3).groups(2)",
1808
        input_size=(2, 4, 6),
1809
        cudnn=True,
1810
    ),
1811
    dict(
1812
        fullname="ConvTranspose1d",
1813
        constructor=lambda: nn.ConvTranspose1d(
1814
            3, 4, kernel_size=3, stride=(3,), padding=1, output_padding=(1,)
1815
        ),
1816
        cpp_constructor_args="torch::nn::ConvTranspose1dOptions(3, 4, 3).stride(3).padding(1).output_padding(1)",
1817
        cudnn=True,
1818
        input_size=(1, 3, 7),
1819
    ),
1820
    dict(
1821
        module_name="ConvTranspose1d",
1822
        constructor_args=(3, 4, 3, 2, 1, 1, 1, False),
1823
        cpp_constructor_args="""torch::nn::ConvTranspose1dOptions(3, 4, 3)
1824
                                .stride(2).padding(1).output_padding(1).groups(1).bias(false)""",
1825
        input_size=(1, 3, 6),
1826
        cudnn=True,
1827
        desc="no_bias",
1828
    ),
1829
    dict(
1830
        module_name="ConvTranspose1d",
1831
        constructor_args=(3, 4, 3, 2, 1, 1, 1, True, 2),
1832
        cpp_constructor_args="""torch::nn::ConvTranspose1dOptions(3, 4, 3)
1833
                                .stride(2).padding(1).output_padding(1).groups(1).bias(true).dilation(2)""",
1834
        input_size=(1, 3, 6),
1835
        cudnn=True,
1836
        desc="dilated",
1837
    ),
1838
    dict(
1839
        fullname="ConvTranspose1d_groups",
1840
        constructor=lambda: nn.ConvTranspose1d(
1841
            4, 6, 3, stride=(3,), padding=1, output_padding=(1,), groups=2
1842
        ),
1843
        cpp_constructor_args="""torch::nn::ConvTranspose1dOptions(4, 6, 3)
1844
                                .stride(3).padding(1).output_padding(1).groups(2)""",
1845
        cudnn=True,
1846
        input_size=(2, 4, 7),
1847
    ),
1848
    dict(
1849
        module_name="MaxPool1d",
1850
        constructor_args=(4,),
1851
        cpp_constructor_args="torch::nn::MaxPool1dOptions(4)",
1852
        input_size=(2, 10, 4),
1853
    ),
1854
    dict(
1855
        module_name="MaxPool1d",
1856
        constructor_args=(4, 4),
1857
        cpp_constructor_args="torch::nn::MaxPool1dOptions(4).stride(4)",
1858
        input_size=(2, 10, 4),
1859
        desc="stride",
1860
    ),
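    # Conv2d / ConvTranspose2d configurations, including grouped and depthwise
    # variants (groups == in_channels, e.g. nn.Conv2d(4, 4, (3, 3), groups=4),
    # is a depthwise convolution).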
    dict(
1862
        module_name="Conv2d",
1863
        constructor_args=(3, 4, (3, 2)),
1864
        cpp_constructor_args="torch::nn::Conv2dOptions(3, 4, {3, 2})",
1865
        input_size=(2, 3, 7, 5),
1866
        cudnn=True,
1867
        check_with_long_tensor=True,
1868
    ),
1869
    dict(
1870
        module_name="Conv2d",
1871
        constructor_args=(3, 4, (3, 3), (2, 2)),
1872
        cpp_constructor_args="torch::nn::Conv2dOptions(3, 4, {3, 3}).stride({2, 2})",
1873
        input_size=(2, 3, 6, 6),
1874
        cudnn=True,
1875
        desc="strided",
1876
        check_with_long_tensor=True,
1877
    ),
1878
    dict(
1879
        module_name="Conv2d",
1880
        constructor_args=(3, 4, (3, 3), (2, 2), (1, 1)),
1881
        cpp_constructor_args="torch::nn::Conv2dOptions(3, 4, {3, 3}).stride({2, 2}).padding({1, 1})",
1882
        input_size=(2, 3, 6, 6),
1883
        cudnn=True,
1884
        desc="padding",
1885
        check_with_long_tensor=True,
1886
    ),
1887
    dict(
1888
        module_name="Conv2d",
1889
        constructor_args=(3, 2, (3, 3), (2, 2), (1, 1), (2, 2)),
1890
        cpp_constructor_args="torch::nn::Conv2dOptions(3, 2, {3, 3}).stride({2, 2}).padding({1, 1}).dilation({2, 2})",
1891
        input_size=(2, 3, 8, 8),
1892
        cudnn=True,
1893
        desc="dilated",
1894
        check_with_long_tensor=True,
1895
    ),
1896
    dict(
1897
        module_name="Conv2d",
1898
        constructor_args=(3, 4, (3, 2), 1, 0, 1, 1, False),
1899
        cpp_constructor_args="""torch::nn::Conv2dOptions(3, 4, {3, 2})
1900
                                .stride(1).padding(0).dilation(1).groups(1).bias(false)""",
1901
        input_size=(2, 3, 6, 5),
1902
        cudnn=True,
1903
        desc="no_bias",
1904
        check_with_long_tensor=True,
1905
    ),
1906
    dict(
1907
        module_name="Conv2d",
1908
        constructor_args=(3, 4, (3, 2)),
1909
        cpp_constructor_args="torch::nn::Conv2dOptions(3, 4, {3, 2})",
1910
        input_size=(0, 3, 7, 5),
1911
        cudnn=True,
1912
        desc="zero_batch",
1913
        check_with_long_tensor=True,
1914
        test_cuda=(not TEST_WITH_ROCM),
1915
    ),
1916
    dict(
1917
        fullname="Conv2d_groups",
1918
        constructor=lambda: nn.Conv2d(4, 6, (3, 2), groups=2),
1919
        cpp_constructor_args="torch::nn::Conv2dOptions(4, 6, {3, 2}).groups(2)",
1920
        input_size=(2, 4, 6, 5),
1921
        cudnn=True,
1922
        check_with_long_tensor=True,
1923
    ),
1924
    dict(
1925
        fullname="Conv2d_groups_thnn",
1926
        constructor=lambda: nn.Conv2d(4, 6, (3, 2), groups=2),
1927
        cpp_constructor_args="torch::nn::Conv2dOptions(4, 6, {3, 2}).groups(2)",
1928
        input_size=(2, 4, 6, 5),
1929
        check_with_long_tensor=True,
1930
    ),
1931
    dict(
1932
        module_name="ConvTranspose2d",
1933
        constructor_args=(3, 4, 3, (3, 2), 1, (1, 1)),
1934
        cpp_constructor_args="""torch::nn::ConvTranspose2dOptions(3, 4, 3)
1935
                                .stride({3, 2}).padding(1).output_padding({1, 1})""",
1936
        cudnn=True,
1937
        input_size=(1, 3, 7, 6),
1938
        check_with_long_tensor=True,
1939
    ),
1940
    dict(
1941
        module_name="ConvTranspose2d",
1942
        constructor_args=(3, 4, 3, (2, 3), 1, (1, 1), 1, False, (2, 2)),
1943
        cpp_constructor_args="""torch::nn::ConvTranspose2dOptions(3, 4, 3)
1944
                                .stride({2, 3})
1945
                                .padding(1)
1946
                                .output_padding({1, 1})
1947
                                .groups(1)
1948
                                .bias(false)
1949
                                .dilation({2, 2})""",
1950
        input_size=(1, 3, 6, 7),
1951
        cudnn=True,
1952
        desc="dilated",
1953
        check_with_long_tensor=True,
1954
    ),
1955
    dict(
1956
        module_name="ConvTranspose2d",
1957
        constructor_args=(3, 4, 3, (2, 3), 1, (1, 1), 1, False),
1958
        cpp_constructor_args="""torch::nn::ConvTranspose2dOptions(3, 4, 3)
1959
                                .stride({2, 3}).padding(1).output_padding({1, 1}).groups(1).bias(false)""",
1960
        input_size=(1, 3, 6, 7),
1961
        cudnn=True,
1962
        desc="no_bias",
1963
        check_with_long_tensor=True,
1964
    ),
1965
    dict(
1966
        fullname="ConvTranspose2d_groups",
1967
        constructor=lambda: nn.ConvTranspose2d(2, 4, (2, 3), groups=2),
1968
        cpp_constructor_args="torch::nn::ConvTranspose2dOptions(2, 4, {2, 3}).groups(2)",
1969
        input_size=(1, 2, 4, 5),
1970
        cudnn=True,
1971
        check_with_long_tensor=True,
1972
    ),
1973
    dict(
1974
        fullname="Conv2d_depthwise",
1975
        constructor=lambda: nn.Conv2d(4, 4, (3, 3), groups=4),
1976
        cpp_constructor_args="torch::nn::Conv2dOptions(4, 4, {3, 3}).groups(4)",
1977
        input_size=(2, 4, 6, 6),
1978
    ),
1979
    dict(
1980
        fullname="Conv2d_depthwise_with_multiplier",
1981
        constructor=lambda: nn.Conv2d(4, 8, (3, 3), groups=4),
1982
        cpp_constructor_args="torch::nn::Conv2dOptions(4, 8, {3, 3}).groups(4)",
1983
        input_size=(2, 4, 6, 6),
1984
    ),
1985
    dict(
1986
        fullname="Conv2d_depthwise_strided",
1987
        constructor=lambda: nn.Conv2d(4, 4, (3, 3), stride=(2, 2), groups=4),
1988
        cpp_constructor_args="torch::nn::Conv2dOptions(4, 4, {3, 3}).stride({2, 2}).groups(4)",
1989
        input_size=(2, 4, 6, 6),
1990
    ),
1991
    dict(
1992
        fullname="Conv2d_depthwise_padded",
1993
        constructor=lambda: nn.Conv2d(4, 4, (3, 3), padding=(1, 1), groups=4),
1994
        cpp_constructor_args="torch::nn::Conv2dOptions(4, 4, {3, 3}).padding({1, 1}).groups(4)",
1995
        input_size=(2, 4, 6, 6),
1996
    ),
1997
    dict(
1998
        fullname="Conv2d_depthwise_dilated",
1999
        constructor=lambda: nn.Conv2d(4, 4, (2, 2), dilation=(2, 2), groups=4),
2000
        cpp_constructor_args="torch::nn::Conv2dOptions(4, 4, {2, 2}).dilation({2, 2}).groups(4)",
2001
        input_size=(2, 4, 5, 5),
2002
    ),
2003
    dict(
2004
        module_name="MaxPool2d",
2005
        constructor_args=((3, 3), (2, 2), (1, 1)),
2006
        cpp_constructor_args="torch::nn::MaxPool2dOptions({3, 3}).stride({2, 2}).padding({1, 1})",
2007
        input_size=(3, 7, 7),
2008
        desc="3d_input",
2009
        check_gradgrad=False,
2010
    ),
2011
    dict(
2012
        module_name="MaxPool2d",
2013
        constructor_args=((3, 3), (2, 2), (1, 1)),
2014
        cpp_constructor_args="torch::nn::MaxPool2dOptions({3, 3}).stride({2, 2}).padding({1, 1})",
2015
        input_size=(1, 3, 7, 7),
2016
        check_with_channels_last=True,
2017
        desc="4d_input",
2018
        check_gradgrad=False,
2019
    ),
2020
    dict(
2021
        module_name="AvgPool1d",
2022
        constructor_args=(2,),
2023
        cpp_constructor_args="torch::nn::AvgPool1dOptions(2)",
2024
        input_size=(2, 3, 6),
2025
    ),
2026
    dict(
2027
        module_name="AvgPool1d",
2028
        constructor_args=((2,), (2,)),
2029
        cpp_constructor_args="torch::nn::AvgPool1dOptions(2).stride(2)",
2030
        input_size=(2, 3, 6),
2031
        desc="stride",
2032
    ),
2033
    dict(
2034
        module_name="AvgPool1d",
2035
        constructor_args=(2, 2, 1),
2036
        cpp_constructor_args="torch::nn::AvgPool1dOptions(2).stride(2).padding(1)",
2037
        input_size=(2, 3, 6),
2038
        desc="stride_pad",
2039
    ),
2040
    dict(
2041
        module_name="AvgPool2d",
2042
        constructor_args=((2, 2),),
2043
        cpp_constructor_args="torch::nn::AvgPool2dOptions({2, 2})",
2044
        input_size=(2, 3, 6, 6),
2045
    ),
2046
    dict(
2047
        module_name="AvgPool2d",
2048
        constructor_args=((2, 2), (2, 2)),
2049
        cpp_constructor_args="torch::nn::AvgPool2dOptions({2, 2}).stride({2, 2})",
2050
        input_size=(2, 3, 6, 6),
2051
        desc="stride",
2052
    ),
2053
    dict(
2054
        module_name="AvgPool2d",
2055
        constructor_args=((2, 2), (2, 2), (1, 1)),
2056
        cpp_constructor_args="torch::nn::AvgPool2dOptions({2, 2}).stride({2, 2}).padding({1, 1})",
2057
        input_size=(2, 3, 6, 6),
2058
        desc="stride_pad",
2059
    ),
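    # AvgPool2d divisor_override variants: divisor_override=1 replaces the
    # kernel-area divisor, so each output element is just the window sum.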
    dict(
2061
        fullname="AvgPool2d_divisor",
2062
        constructor=lambda: nn.AvgPool2d((2, 2), divisor_override=1),
2063
        cpp_constructor_args="torch::nn::AvgPool2dOptions({2, 2}).divisor_override(1)",
2064
        input_size=(2, 3, 6, 6),
2065
        check_with_long_tensor=True,
2066
    ),
2067
    dict(
2068
        fullname="AvgPool2d_divisor_stride",
2069
        constructor=lambda: nn.AvgPool2d((2, 2), (2, 2), divisor_override=1),
2070
        cpp_constructor_args="torch::nn::AvgPool2dOptions({2, 2}).stride({2, 2}).divisor_override(1)",
2071
        input_size=(2, 3, 6, 6),
2072
        check_with_long_tensor=True,
2073
    ),
2074
    dict(
2075
        fullname="AvgPool2d_divisor_stride_pad",
2076
        constructor=lambda: nn.AvgPool2d((2, 2), (2, 2), (1, 1), divisor_override=1),
2077
        cpp_constructor_args="torch::nn::AvgPool2dOptions({2, 2}).stride({2, 2}).padding({1, 1}).divisor_override(1)",
2078
        input_size=(2, 3, 6, 6),
2079
        check_with_long_tensor=True,
2080
    ),
2081
    dict(
2082
        module_name="LPPool2d",
2083
        constructor_args=(2, 2, 2),
2084
        cpp_constructor_args="torch::nn::LPPool2dOptions(2, 2).stride(2)",
2085
        input_size=(1, 3, 7, 7),
2086
    ),
2087
    dict(
2088
        module_name="LPPool2d",
2089
        constructor_args=(1.5, 2),
2090
        cpp_constructor_args="torch::nn::LPPool2dOptions(1.5, 2)",
2091
        input_fn=lambda: torch.rand(1, 3, 7, 7),
2092
        desc="norm",
2093
    ),
2094
    dict(
2095
        module_name="LPPool1d",
2096
        constructor_args=(1.5, 2),
2097
        cpp_constructor_args="torch::nn::LPPool1dOptions(1.5, 2)",
2098
        input_fn=lambda: torch.rand(1, 3, 7),
2099
        desc="norm",
2100
    ),
2101
    dict(
2102
        module_name="LPPool1d",
2103
        constructor_args=(2, 2, 3),
2104
        cpp_constructor_args="torch::nn::LPPool1dOptions(2, 2).stride(3)",
2105
        input_size=(1, 3, 7),
2106
    ),
2107
    dict(
2108
        module_name="LocalResponseNorm",
2109
        constructor_args=(3,),
2110
        cpp_constructor_args="torch::nn::LocalResponseNormOptions(3)",
2111
        input_size=(1, 5, 7),
2112
        desc="1d",
2113
    ),
2114
    dict(
2115
        module_name="LocalResponseNorm",
2116
        constructor_args=(2,),
2117
        cpp_constructor_args="torch::nn::LocalResponseNormOptions(2)",
2118
        input_size=(1, 5, 7, 7),
2119
        desc="2d_uneven_pad",
2120
    ),
2121
    dict(
2122
        module_name="LocalResponseNorm",
2123
        constructor_args=(1, 1.0, 0.5, 2.0),
2124
        cpp_constructor_args="torch::nn::LocalResponseNormOptions(1).alpha(1.).beta(0.5).k(2.)",
2125
        input_size=(1, 5, 7, 7, 7),
2126
        desc="3d_custom_params",
2127
    ),
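    # Padding module configurations. Padding tuples follow F.pad ordering:
    # (left, right) in 1d, (left, right, top, bottom) in 2d, plus
    # (front, back) in 3d; negative values, as in the ZeroPad2d negative_dims
    # case, crop instead of pad.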
    dict(
2129
        module_name="ReflectionPad1d",
2130
        constructor_args=((1, 2),),
2131
        cpp_constructor_args="torch::nn::ReflectionPad1dOptions({1, 2})",
2132
        input_size=(2, 3, 8),
2133
    ),
2134
    dict(
2135
        module_name="ReflectionPad2d",
2136
        constructor_args=((1, 2, 3, 4),),
2137
        cpp_constructor_args="torch::nn::ReflectionPad2dOptions({1, 2, 3, 4})",
2138
        input_size=(2, 3, 8, 8),
2139
    ),
2140
    dict(
2141
        module_name="ReplicationPad1d",
2142
        constructor_args=((1, 2),),
2143
        cpp_constructor_args="torch::nn::ReplicationPad1dOptions({1, 2})",
2144
        input_size=(2, 3, 4),
2145
    ),
2146
    dict(
2147
        module_name="ReplicationPad2d",
2148
        constructor_args=((1, 2, 3, 4),),
2149
        cpp_constructor_args="torch::nn::ReplicationPad2dOptions({1, 2, 3, 4})",
2150
        input_size=(2, 3, 4, 4),
2151
    ),
2152
    dict(
2153
        module_name="ZeroPad2d",
2154
        constructor_args=((1, 2, 3, 4),),
2155
        cpp_constructor_args="torch::nn::ZeroPad2dOptions({1, 2, 3, 4})",
2156
        input_size=(2, 3, 4, 4),
2157
    ),
2158
    dict(
2159
        module_name="ZeroPad2d",
2160
        constructor_args=((-1, -1, -1, -2),),
2161
        cpp_constructor_args="torch::nn::ZeroPad2dOptions({-1, -1, -1, -2})",
2162
        input_size=(2, 3, 4, 4),
2163
        desc="negative_dims",
2164
    ),
2165
    dict(
2166
        module_name="ConstantPad1d",
2167
        constructor_args=((1, 2), 2.0),
2168
        cpp_constructor_args="torch::nn::ConstantPad1dOptions({1, 2}, 2.)",
2169
        input_size=(2, 3, 4),
2170
    ),
2171
    dict(
2172
        module_name="ConstantPad2d",
2173
        constructor_args=((1, 2, 3, 4), 2.0),
2174
        cpp_constructor_args="torch::nn::ConstantPad2dOptions({1, 2, 3, 4}, 2.)",
2175
        input_size=(2, 3, 4, 4),
2176
    ),
2177
    dict(
2178
        module_name="ConstantPad3d",
2179
        constructor_args=((1, 2, 3, 4, 1, 0), 2.0),
2180
        cpp_constructor_args="torch::nn::ConstantPad3dOptions({1, 2, 3, 4, 1, 0}, 2.)",
2181
        input_size=(2, 3, 4, 4, 5),
2182
    ),
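    # Conv3d / ConvTranspose3d configurations; kernel sizes are either a
    # single int or a (D, H, W) tuple matching the 5-d input_size
    # (N, C, D, H, W).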
    dict(
2184
        module_name="Conv3d",
2185
        constructor_args=(3, 4, (2, 3, 4)),
2186
        cpp_constructor_args="torch::nn::Conv3dOptions(3, 4, {2, 3, 4})",
2187
        input_size=(2, 3, 3, 4, 5),
2188
        cudnn=True,
2189
        check_with_long_tensor=True,
2190
    ),
2191
    dict(
2192
        module_name="Conv3d",
2193
        constructor_args=(3, 4, (2, 3, 4), 1, 0, 1, 1, False),
2194
        cpp_constructor_args="""torch::nn::Conv3dOptions(3, 4, {2, 3, 4})
2195
                                .stride(1).padding(0).dilation(1).groups(1).bias(false)""",
2196
        input_size=(2, 3, 3, 4, 5),
2197
        cudnn=True,
2198
        desc="no_bias",
2199
        check_with_long_tensor=True,
2200
    ),
2201
    dict(
2202
        module_name="Conv3d",
2203
        constructor_args=(3, 4, 2, 2),
2204
        cpp_constructor_args="torch::nn::Conv3dOptions(3, 4, 2).stride(2)",
2205
        input_size=(2, 3, 5, 5, 5),
2206
        cudnn=True,
2207
        desc="stride",
2208
        check_with_long_tensor=True,
2209
    ),
2210
    dict(
2211
        module_name="Conv3d",
2212
        constructor_args=(3, 4, 2, 2, 1),
2213
        cpp_constructor_args="torch::nn::Conv3dOptions(3, 4, 2).stride(2).padding(1)",
2214
        input_size=(2, 3, 5, 5, 5),
2215
        cudnn=True,
2216
        desc="stride_padding",
2217
        check_with_long_tensor=True,
2218
    ),
2219
    dict(
2220
        module_name="Conv3d",
2221
        constructor_args=(3, 4, (2, 3, 4)),
2222
        cpp_constructor_args="torch::nn::Conv3dOptions(3, 4, {2, 3, 4})",
2223
        input_size=(0, 3, 3, 4, 5),
2224
        cudnn=True,
2225
        check_with_long_tensor=True,
2226
        desc="zero_batch",
2227
        test_cuda=(not TEST_WITH_ROCM),
2228
    ),
2229
    dict(
2230
        fullname="Conv3d_groups",
2231
        constructor=lambda: nn.Conv3d(4, 6, kernel_size=3, groups=2),
2232
        cpp_constructor_args="torch::nn::Conv3dOptions(4, 6, 3).groups(2)",
2233
        input_size=(2, 4, 4, 5, 4),
2234
        cudnn=True,
2235
        check_with_long_tensor=True,
2236
    ),
2237
    dict(
2238
        fullname="Conv3d_dilated",
2239
        constructor=lambda: nn.Conv3d(3, 4, kernel_size=2, dilation=2),
2240
        cpp_constructor_args="torch::nn::Conv3dOptions(3, 4, 2).dilation(2)",
2241
        input_size=(2, 3, 5, 5, 5),
2242
    ),
2243
    dict(
2244
        fullname="Conv3d_dilated_strided",
2245
        constructor=lambda: nn.Conv3d(3, 4, kernel_size=2, dilation=2, stride=2),
2246
        cpp_constructor_args="torch::nn::Conv3dOptions(3, 4, 2).dilation(2).stride(2)",
2247
        input_size=(2, 3, 5, 5, 5),
2248
    ),
2249
    dict(
2250
        module_name="ConvTranspose3d",
2251
        constructor_args=(2, 3, (2, 3, 2)),
2252
        cpp_constructor_args="torch::nn::ConvTranspose3dOptions(2, 3, {2, 3, 2})",
2253
        cudnn=True,
2254
        input_size=(1, 2, 4, 5, 4),
2255
    ),
2256
    dict(
2257
        module_name="ConvTranspose3d",
2258
        constructor_args=(2, 3, (2, 3, 2), 1, 0, 0, 1, True, (2, 2, 2)),
2259
        cpp_constructor_args="""torch::nn::ConvTranspose3dOptions(2, 3, {2, 3, 2})
2260
                                .stride(1).padding(0).output_padding(0).groups(1).bias(true).dilation({2, 2, 2})""",
2261
        cudnn=True,
2262
        input_size=(1, 2, 4, 5, 4),
2263
        desc="dilated",
2264
    ),
2265
    dict(
2266
        module_name="MaxPool3d",
2267
        constructor_args=((2, 2, 2),),
2268
        cpp_constructor_args="torch::nn::MaxPool3dOptions({2, 2, 2})",
2269
        input_size=(2, 3, 5, 5, 5),
2270
        check_gradgrad=False,
2271
    ),
2272
    dict(
2273
        module_name="MaxPool3d",
2274
        constructor_args=(2, (2, 2, 2)),
2275
        cpp_constructor_args="torch::nn::MaxPool3dOptions(2).stride({2, 2, 2})",
2276
        input_size=(2, 3, 5, 5, 5),
2277
        desc="stride",
2278
        check_gradgrad=False,
2279
    ),
2280
    dict(
2281
        module_name="MaxPool3d",
2282
        constructor_args=(2, 2, (1, 1, 1)),
2283
        cpp_constructor_args="torch::nn::MaxPool3dOptions(2).stride(2).padding({1, 1, 1})",
2284
        input_size=(2, 3, 5, 5, 5),
2285
        desc="stride_padding",
2286
        check_gradgrad=False,
2287
    ),
2288
    dict(
2289
        module_name="AvgPool3d",
2290
        constructor_args=((2, 2, 2),),
2291
        cpp_constructor_args="torch::nn::AvgPool3dOptions({2, 2, 2})",
2292
        input_size=(2, 3, 4, 4, 4),
2293
    ),
2294
    dict(
2295
        module_name="AvgPool3d",
2296
        constructor_args=(2, (2, 2, 2)),
2297
        cpp_constructor_args="torch::nn::AvgPool3dOptions(2).stride({2, 2, 2})",
2298
        input_size=(2, 3, 5, 5, 5),
2299
        desc="stride",
2300
    ),
2301
    dict(
2302
        module_name="AvgPool3d",
2303
        constructor_args=(2, 2, (1, 1, 1)),
2304
        cpp_constructor_args="torch::nn::AvgPool3dOptions(2).stride(2).padding({1, 1, 1})",
2305
        input_size=(2, 3, 5, 5, 5),
2306
        desc="stride_pad",
2307
    ),
2308
    dict(
2309
        module_name="AvgPool3d",
2310
        constructor_args=(4, 2, (1, 2, 1)),
2311
        cpp_constructor_args="torch::nn::AvgPool3dOptions(4).stride(2).padding({1, 2, 1})",
2312
        input_size=(2, 3, 5, 5, 5),
2313
        desc="stride_pad_gpu_fixedkw_output",
2314
    ),
2315
    dict(
2316
        module_name="AvgPool3d",
2317
        constructor_args=((2, 4, 8), 1, (1, 1, 2)),
2318
        cpp_constructor_args="torch::nn::AvgPool3dOptions({2, 4, 8}).stride(1).padding({1, 1, 2})",
2319
        input_size=(2, 3, 2, 4, 8),
2320
        desc="stride_pad_gpu_general_output",
2321
    ),
2322
    dict(
2323
        module_name="AvgPool3d",
2324
        constructor_args=(3, 1, 0),
2325
        cpp_constructor_args="torch::nn::AvgPool3dOptions(3).stride(1).padding(0)",
2326
        input_size=(2, 3, 4, 4, 4),
2327
        desc="stride1_pad0_gpu_input",
2328
    ),
2329
    dict(
2330
        module_name="AvgPool3d",
2331
        constructor_args=(2, 2, (1, 1, 1)),
2332
        cpp_constructor_args="torch::nn::AvgPool3dOptions(2).stride(2).padding({1, 1, 1})",
2333
        input_size=(2, 3, 4, 4, 4),
2334
        desc="stride_pad_gpu_input_nooverlap",
2335
    ),
2336
    dict(
2337
        fullname="AvgPool3d_divisor",
2338
        constructor=lambda: nn.AvgPool3d((2, 2, 2), divisor_override=1),
2339
        cpp_constructor_args="torch::nn::AvgPool3dOptions({2, 2, 2}).divisor_override(1)",
2340
        input_size=(2, 3, 4, 4, 4),
2341
        check_with_long_tensor=True,
2342
    ),
2343
    dict(
2344
        fullname="AvgPool3d_divisor_stride",
2345
        constructor=lambda: nn.AvgPool3d(2, (2, 2, 2), divisor_override=1),
2346
        cpp_constructor_args="torch::nn::AvgPool3dOptions(2).stride({2, 2, 2}).divisor_override(1)",
2347
        input_size=(2, 3, 5, 5, 5),
2348
        check_with_long_tensor=True,
2349
    ),
2350
    dict(
2351
        fullname="AvgPool3d_divisor_stride_pad",
2352
        constructor=lambda: nn.AvgPool3d(2, 2, (1, 1, 1), divisor_override=1),
2353
        cpp_constructor_args="torch::nn::AvgPool3dOptions(2).stride(2).padding({1, 1, 1}).divisor_override(1)",
2354
        input_size=(2, 3, 5, 5, 5),
2355
        check_with_long_tensor=True,
2356
    ),
2357
    dict(
2358
        fullname="AvgPool3d_divisor_stride_pad_gpu_fixedkw_output",
2359
        constructor=lambda: nn.AvgPool3d(4, 2, (1, 2, 1), divisor_override=1),
2360
        cpp_constructor_args="torch::nn::AvgPool3dOptions(4).stride(2).padding({1, 2, 1}).divisor_override(1)",
2361
        input_size=(2, 3, 5, 5, 5),
2362
        check_with_long_tensor=True,
2363
    ),
2364
    dict(
2365
        fullname="AvgPool3d_divisor_stride_pad_gpu_general_output",
2366
        constructor=lambda: nn.AvgPool3d((2, 4, 8), 1, (1, 1, 2), divisor_override=1),
2367
        cpp_constructor_args="torch::nn::AvgPool3dOptions({2, 4, 8}).stride(1).padding({1, 1, 2}).divisor_override(1)",
2368
        input_size=(2, 3, 2, 4, 8),
2369
        check_with_long_tensor=True,
2370
    ),
2371
    dict(
2372
        fullname="AvgPool3d_divisor_stride1_pad0_gpu_input",
2373
        constructor=lambda: nn.AvgPool3d(3, 1, 0, divisor_override=1),
2374
        cpp_constructor_args="torch::nn::AvgPool3dOptions(3).stride(1).padding(0).divisor_override(1)",
2375
        input_size=(2, 3, 4, 4, 4),
2376
        check_with_long_tensor=True,
2377
    ),
2378
    dict(
2379
        fullname="AvgPool3d_divisor_stride_pad_gpu_input_nooverlap",
2380
        constructor=lambda: nn.AvgPool3d(2, 2, (1, 1, 1), divisor_override=1),
2381
        cpp_constructor_args="torch::nn::AvgPool3dOptions(2).stride(2).padding({1, 1, 1}).divisor_override(1)",
2382
        input_size=(2, 3, 4, 4, 4),
2383
        check_with_long_tensor=True,
2384
    ),
2385
    dict(
2386
        module_name="ReplicationPad3d",
2387
        constructor_args=((1, 2, 3, 4, 5, 6),),
2388
        cpp_constructor_args="torch::nn::ReplicationPad3dOptions({1, 2, 3, 4, 5, 6})",
2389
        input_size=(2, 3, 5, 5, 5),
2390
    ),
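    # Embedding / EmbeddingBag configurations. Inputs are integer index
    # tensors rather than float tensors, e.g.
    # torch.empty(2, 3, dtype=torch.long).random_(4) for a 4-row table.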
    dict(
2392
        module_name="Embedding",
2393
        constructor_args=(4, 3),
2394
        cpp_constructor_args="torch::nn::EmbeddingOptions(4, 3)",
2395
        input_fn=lambda: torch.empty(2, 3, dtype=torch.long).random_(4),
2396
        jacobian_input=False,
2397
        check_gradgrad=False,
2398
    ),
2399
    dict(
2400
        module_name="EmbeddingBag",
2401
        constructor_args=(4, 3),
2402
        cpp_constructor_args="torch::nn::EmbeddingBagOptions(4, 3)",
2403
        input_fn=lambda: torch.empty(2, 3, dtype=torch.long).random_(4),
2404
        jacobian_input=False,
2405
        check_gradgrad=False,
2406
        check_forward_only=True,
2407
        desc="mean",
2408
    ),
2409
    dict(
2410
        module_name="EmbeddingBag",
2411
        constructor_args=(4, 3, None, 2.0, False, "sum"),
2412
        cpp_constructor_args="""torch::nn::EmbeddingBagOptions(4, 3)
2413
                                .max_norm(c10::nullopt).norm_type(2.).scale_grad_by_freq(false).mode(torch::kSum)""",
2414
        input_fn=lambda: torch.empty(2, 3, dtype=torch.long).random_(4),
2415
        jacobian_input=False,
2416
        check_gradgrad=False,
2417
        check_forward_only=True,
2418
        desc="sum",
2419
    ),
2420
    dict(
2421
        module_name="EmbeddingBag",
2422
        constructor_args=(4, 3, None, 2.0, False, "max"),
2423
        cpp_constructor_args="""torch::nn::EmbeddingBagOptions(4, 3)
2424
                                .max_norm(c10::nullopt).norm_type(2.).scale_grad_by_freq(false).mode(torch::kMax)""",
2425
        input_fn=lambda: torch.empty(2, 3, dtype=torch.long).random_(4),
2426
        jacobian_input=False,
2427
        check_gradgrad=False,
2428
        check_forward_only=True,
2429
        desc="max",
2430
    ),
2431
    dict(
2432
        fullname="EmbeddingBag_sparse",
2433
        constructor=lambda: nn.EmbeddingBag(4, 3, sparse=True),
2434
        cpp_constructor_args="torch::nn::EmbeddingBagOptions(4, 3).sparse(true)",
2435
        input_fn=lambda: torch.randperm(2).repeat(1, 2),
2436
        jacobian_input=False,
2437
        check_gradgrad=False,
2438
    ),
2439
    dict(
2440
        constructor=lambda: nn.Embedding(4, 3, sparse=True),
2441
        cpp_constructor_args="torch::nn::EmbeddingOptions(4, 3).sparse(true)",
2442
        input_fn=lambda: torch.randperm(2).repeat(1, 2),
2443
        jacobian_input=False,
2444
        fullname="Embedding_sparse",
2445
        check_gradgrad=False,
2446
    ),
2447
    dict(
2448
        module_name="PixelShuffle",
2449
        constructor_args=(3,),
2450
        cpp_constructor_args="torch::nn::PixelShuffleOptions(3)",
2451
        input_size=(1, 9, 4, 4),
2452
    ),
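    # Functional F.interpolate cases: these entries wrap the functional API
    # via wrap_functional rather than constructing a module, and pair each
    # call with the C++ F::InterpolateFuncOptions string used for the parity
    # check.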
    dict(
2454
        constructor=wrap_functional(
2455
            F.interpolate, size=12, scale_factor=None, mode="nearest"
2456
        ),
2457
        cpp_options_args="""F::InterpolateFuncOptions()
2458
                            .size(std::vector<int64_t>({12})).scale_factor(c10::nullopt).mode(torch::kNearest)""",
2459
        input_size=(1, 2, 4),
2460
        fullname="interpolate_nearest_1d",
2461
        pickle=False,
2462
    ),
2463
    dict(
2464
        constructor=wrap_functional(
2465
            F.interpolate, size=12, scale_factor=None, mode="nearest"
2466
        ),
2467
        cpp_options_args="""F::InterpolateFuncOptions()
2468
                            .size(std::vector<int64_t>({12})).scale_factor(c10::nullopt).mode(torch::kNearest)""",
2469
        input_size=(0, 2, 4),
2470
        fullname="interpolate_nearest_1d_zero_dim",
2471
        pickle=False,
2472
    ),
2473
    dict(
2474
        constructor=wrap_functional(
2475
            F.interpolate, size=(12,), scale_factor=None, mode="nearest"
2476
        ),
2477
        cpp_options_args="""F::InterpolateFuncOptions()
2478
                            .size(std::vector<int64_t>({12})).scale_factor(c10::nullopt).mode(torch::kNearest)""",
2479
        input_size=(1, 2, 3),
2480
        fullname="interpolate_nearest_tuple_1d",
2481
        pickle=False,
2482
    ),
2483
    dict(
2484
        constructor=wrap_functional(
2485
            F.interpolate, size=None, scale_factor=4.0, mode="nearest"
2486
        ),
2487
        cpp_options_args="""F::InterpolateFuncOptions()
2488
                            .size(c10::nullopt).scale_factor(std::vector<double>({4.})).mode(torch::kNearest)""",
2489
        input_size=(1, 2, 4),
2490
        fullname="interpolate_nearest_scale_1d",
2491
        pickle=False,
2492
    ),
2493
    dict(
2494
        constructor=wrap_functional(
2495
            F.interpolate,
2496
            size=12,
2497
            scale_factor=None,
2498
            mode="linear",
2499
            align_corners=False,
2500
        ),
2501
        cpp_options_args="""F::InterpolateFuncOptions()
2502
                            .size(std::vector<int64_t>({12}))
2503
                            .scale_factor(c10::nullopt)
2504
                            .mode(torch::kLinear)
2505
                            .align_corners(false)""",
2506
        input_size=(1, 2, 4),
2507
        fullname="interpolate_linear_1d",
2508
        pickle=False,
2509
    ),
2510
    dict(
2511
        constructor=wrap_functional(
2512
            F.interpolate,
2513
            size=(4,),
2514
            scale_factor=None,
2515
            mode="linear",
2516
            align_corners=False,
2517
        ),
2518
        cpp_options_args="""F::InterpolateFuncOptions()
2519
                            .size(std::vector<int64_t>({4}))
2520
                            .scale_factor(c10::nullopt)
2521
                            .mode(torch::kLinear)
2522
                            .align_corners(false)""",
2523
        input_size=(1, 2, 3),
2524
        fullname="interpolate_linear_tuple_1d",
2525
        pickle=False,
2526
    ),
2527
    dict(
2528
        constructor=wrap_functional(
2529
            F.interpolate,
2530
            size=None,
2531
            scale_factor=4.0,
2532
            mode="linear",
2533
            align_corners=False,
2534
        ),
2535
        cpp_options_args="""F::InterpolateFuncOptions()
2536
                            .size(c10::nullopt)
2537
                            .scale_factor(std::vector<double>({4.}))
2538
                            .mode(torch::kLinear)
2539
                            .align_corners(false)""",
2540
        input_size=(1, 2, 4),
2541
        fullname="interpolate_linear_scale_1d",
2542
        pickle=False,
2543
    ),
2544
    dict(
2545
        constructor=wrap_functional(
2546
            F.interpolate,
2547
            size=12,
2548
            scale_factor=None,
2549
            mode="linear",
2550
            align_corners=False,
2551
        ),
2552
        cpp_options_args="""F::InterpolateFuncOptions()
2553
                            .size(std::vector<int64_t>({12}))
2554
                            .scale_factor(c10::nullopt)
2555
                            .mode(torch::kLinear)
2556
                            .align_corners(false)""",
2557
        input_size=(0, 2, 4),
2558
        fullname="interpolate_linear_1d_zero_dim",
2559
        pickle=False,
2560
    ),
2561
    dict(
2562
        constructor=wrap_functional(
2563
            F.interpolate, size=12, scale_factor=None, mode="linear", align_corners=True
2564
        ),
2565
        cpp_options_args="""F::InterpolateFuncOptions()
2566
                            .size(std::vector<int64_t>({12}))
2567
                            .scale_factor(c10::nullopt)
2568
                            .mode(torch::kLinear)
2569
                            .align_corners(true)""",
2570
        input_size=(1, 2, 4),
2571
        fullname="interpolate_linear_1d_align_corners",
2572
        pickle=False,
2573
    ),
2574
    dict(
2575
        constructor=wrap_functional(
2576
            F.interpolate,
2577
            size=None,
2578
            scale_factor=4.0,
2579
            mode="linear",
2580
            align_corners=True,
2581
        ),
2582
        cpp_options_args="""F::InterpolateFuncOptions()
2583
                            .size(c10::nullopt)
2584
                            .scale_factor(std::vector<double>({4.}))
2585
                            .mode(torch::kLinear)
2586
                            .align_corners(true)""",
2587
        input_size=(1, 2, 4),
2588
        fullname="interpolate_linear_scale_1d_align_corners",
2589
        pickle=False,
2590
    ),
2591
    dict(
2592
        constructor=wrap_functional(
2593
            F.interpolate, size=2, scale_factor=None, mode="nearest"
2594
        ),
2595
        cpp_options_args="""F::InterpolateFuncOptions()
2596
                            .size(std::vector<int64_t>({2, 2}))
2597
                            .scale_factor(c10::nullopt)
2598
                            .mode(torch::kNearest)""",
2599
        input_size=(1, 128, 1, 1),
2600
        fullname="interpolate_nearest_2d_launch_configs",
2601
        pickle=False,
2602
    ),
2603
    dict(
2604
        constructor=wrap_functional(
2605
            F.interpolate, size=12, scale_factor=None, mode="nearest"
2606
        ),
2607
        cpp_options_args="""F::InterpolateFuncOptions()
2608
                            .size(std::vector<int64_t>({12, 12}))
2609
                            .scale_factor(c10::nullopt)
2610
                            .mode(torch::kNearest)""",
2611
        input_size=(1, 2, 4, 4),
2612
        fullname="interpolate_nearest_2d",
2613
        pickle=False,
2614
    ),
2615
    dict(
2616
        constructor=wrap_functional(
2617
            F.interpolate, size=(12, 16), scale_factor=None, mode="nearest"
2618
        ),
2619
        cpp_options_args="""F::InterpolateFuncOptions()
2620
                            .size(std::vector<int64_t>({12, 16}))
2621
                            .scale_factor(c10::nullopt)
2622
                            .mode(torch::kNearest)""",
2623
        input_size=(1, 2, 3, 4),
2624
        fullname="interpolate_nearest_tuple_2d",
2625
        pickle=False,
2626
    ),
2627
    dict(
2628
        constructor=wrap_functional(
2629
            F.interpolate, size=None, scale_factor=4.0, mode="nearest"
2630
        ),
2631
        cpp_options_args="""F::InterpolateFuncOptions()
2632
                            .size(c10::nullopt)
2633
                            .scale_factor(std::vector<double>({4., 4.}))
2634
                            .mode(torch::kNearest)""",
2635
        input_size=(1, 2, 4, 4),
2636
        fullname="interpolate_nearest_scale_2d",
2637
        pickle=False,
2638
    ),
2639
    dict(
2640
        constructor=wrap_functional(
2641
            F.interpolate, size=12, scale_factor=None, mode="nearest"
2642
        ),
2643
        cpp_options_args="""F::InterpolateFuncOptions()
2644
                            .size(std::vector<int64_t>({12, 12}))
2645
                            .scale_factor(c10::nullopt)
2646
                            .mode(torch::kNearest)""",
2647
        input_size=(0, 2, 4, 4),
2648
        fullname="interpolate_nearest_2d_zero_dim",
2649
        pickle=False,
2650
    ),
2651
    dict(
2652
        constructor=wrap_functional(
            F.interpolate,
            size=12,
            scale_factor=None,
            mode="bilinear",
            align_corners=False,
        ),
        cpp_options_args="""F::InterpolateFuncOptions()
                            .size(std::vector<int64_t>({12, 12}))
                            .scale_factor(c10::nullopt)
                            .mode(torch::kBilinear)
                            .align_corners(false)""",
        input_size=(1, 2, 4, 4),
        fullname="interpolate_bilinear_2d",
        pickle=False,
    ),
    dict(
        constructor=wrap_functional(
            F.interpolate,
            size=12,
            scale_factor=None,
            mode="bilinear",
            align_corners=False,
        ),
        cpp_options_args="""F::InterpolateFuncOptions()
                            .size(std::vector<int64_t>({12, 12}))
                            .scale_factor(c10::nullopt)
                            .mode(torch::kBilinear)
                            .align_corners(false)""",
        input_size=(0, 2, 4, 4),
        fullname="interpolate_bilinear_2d_zero_dim",
        pickle=False,
    ),
    dict(
        constructor=wrap_functional(
            F.interpolate,
            size=(4, 6),
            scale_factor=None,
            mode="bilinear",
            align_corners=False,
        ),
        cpp_options_args="""F::InterpolateFuncOptions()
                            .size(std::vector<int64_t>({4, 6}))
                            .scale_factor(c10::nullopt)
                            .mode(torch::kBilinear)
                            .align_corners(false)""",
        input_size=(1, 2, 2, 3),
        fullname="interpolate_bilinear_tuple_2d",
        pickle=False,
    ),
    dict(
        constructor=wrap_functional(
            F.interpolate,
            size=None,
            scale_factor=4.0,
            mode="bilinear",
            align_corners=False,
        ),
        cpp_options_args="""F::InterpolateFuncOptions()
                            .size(c10::nullopt)
                            .scale_factor(std::vector<double>({4., 4.}))
                            .mode(torch::kBilinear)
                            .align_corners(false)""",
        input_size=(1, 2, 4, 4),
        fullname="interpolate_bilinear_scale_2d",
        pickle=False,
    ),
    dict(
        constructor=wrap_functional(
            F.interpolate,
            size=None,
            scale_factor=(2.0, 2.0),
            mode="bilinear",
            align_corners=False,
        ),
        cpp_options_args="""F::InterpolateFuncOptions()
                            .size(c10::nullopt)
                            .scale_factor(std::vector<double>({2., 2.}))
                            .mode(torch::kBilinear)
                            .align_corners(false)""",
        input_size=(1, 2, 4, 4),
        fullname="interpolate_bilinear_scale_tuple_shared_2d",
        pickle=False,
    ),
    dict(
        constructor=wrap_functional(
            F.interpolate,
            size=None,
            scale_factor=(2.0, 1.0),
            mode="bilinear",
            align_corners=False,
        ),
        cpp_options_args="""F::InterpolateFuncOptions()
                            .size(c10::nullopt)
                            .scale_factor(std::vector<double>({2., 1.}))
                            .mode(torch::kBilinear)
                            .align_corners(false)""",
        input_size=(1, 2, 4, 4),
        fullname="interpolate_bilinear_scale_tuple_skewed_2d",
        pickle=False,
    ),
    dict(
        constructor=wrap_functional(
            F.interpolate,
            size=(4, 6),
            scale_factor=None,
            mode="bilinear",
            align_corners=True,
        ),
        cpp_options_args="""F::InterpolateFuncOptions()
                            .size(std::vector<int64_t>({4, 6}))
                            .scale_factor(c10::nullopt)
                            .mode(torch::kBilinear)
                            .align_corners(true)""",
        input_size=(1, 2, 4, 4),
        fullname="interpolate_bilinear_tuple_2d_align_corners",
        pickle=False,
    ),
    dict(
        constructor=wrap_functional(
            F.interpolate,
            size=None,
            scale_factor=(2.0, 1.0),
            mode="bilinear",
            align_corners=True,
        ),
        cpp_options_args="""F::InterpolateFuncOptions()
                            .size(c10::nullopt)
                            .scale_factor(std::vector<double>({2., 1.}))
                            .mode(torch::kBilinear)
                            .align_corners(true)""",
        input_size=(1, 2, 4, 4),
        fullname="interpolate_bilinear_scale_tuple_skewed_2d_align_corners",
        pickle=False,
    ),
    dict(
        constructor=wrap_functional(
            F.interpolate,
            size=12,
            scale_factor=None,
            mode="bicubic",
            align_corners=False,
        ),
        cpp_options_args="""F::InterpolateFuncOptions()
                            .size(std::vector<int64_t>({12, 12}))
                            .scale_factor(c10::nullopt)
                            .mode(torch::kBicubic)
                            .align_corners(false)""",
        input_size=(1, 2, 4, 4),
        fullname="interpolate_bicubic_2d",
        pickle=False,
    ),
    dict(
        constructor=wrap_functional(
            F.interpolate,
            size=12,
            scale_factor=None,
            mode="bicubic",
            align_corners=False,
        ),
        cpp_options_args="""F::InterpolateFuncOptions()
                            .size(std::vector<int64_t>({12, 12}))
                            .scale_factor(c10::nullopt)
                            .mode(torch::kBicubic)
                            .align_corners(false)""",
        input_size=(0, 2, 4, 4),
        fullname="interpolate_bicubic_2d_zero_dim",
        pickle=False,
    ),
    dict(
        constructor=wrap_functional(
            F.interpolate,
            size=(4, 6),
            scale_factor=None,
            mode="bicubic",
            align_corners=False,
        ),
        cpp_options_args="""F::InterpolateFuncOptions()
                            .size(std::vector<int64_t>({4, 6}))
                            .scale_factor(c10::nullopt)
                            .mode(torch::kBicubic)
                            .align_corners(false)""",
        input_size=(1, 2, 2, 3),
        fullname="interpolate_bicubic_tuple_2d",
        pickle=False,
    ),
    dict(
        constructor=wrap_functional(
            F.interpolate,
            size=None,
            scale_factor=4.0,
            mode="bicubic",
            align_corners=False,
        ),
        cpp_options_args="""F::InterpolateFuncOptions()
                            .size(c10::nullopt)
                            .scale_factor(std::vector<double>({4., 4.}))
                            .mode(torch::kBicubic)
                            .align_corners(false)""",
        input_size=(1, 2, 4, 4),
        fullname="interpolate_bicubic_scale_2d",
        pickle=False,
    ),
    dict(
        constructor=wrap_functional(
            F.interpolate,
            size=None,
            scale_factor=(2.0, 2.0),
            mode="bicubic",
            align_corners=False,
        ),
        cpp_options_args="""F::InterpolateFuncOptions()
                            .size(c10::nullopt)
                            .scale_factor(std::vector<double>({2., 2.}))
                            .mode(torch::kBicubic)
                            .align_corners(false)""",
        input_size=(1, 2, 4, 4),
        fullname="interpolate_bicubic_scale_tuple_shared_2d",
        pickle=False,
    ),
    dict(
        constructor=wrap_functional(
            F.interpolate,
            size=None,
            scale_factor=(2.0, 1.0),
            mode="bicubic",
            align_corners=False,
        ),
        cpp_options_args="""F::InterpolateFuncOptions()
                            .size(c10::nullopt)
                            .scale_factor(std::vector<double>({2., 1.}))
                            .mode(torch::kBicubic)
                            .align_corners(false)""",
        input_size=(1, 2, 4, 4),
        fullname="interpolate_bicubic_scale_tuple_skewed_2d",
        pickle=False,
    ),
    dict(
        constructor=wrap_functional(
            F.interpolate,
            size=(4, 6),
            scale_factor=None,
            mode="bicubic",
            align_corners=True,
        ),
        cpp_options_args="""F::InterpolateFuncOptions()
                            .size(std::vector<int64_t>({4, 6}))
                            .scale_factor(c10::nullopt)
                            .mode(torch::kBicubic)
                            .align_corners(true)""",
        input_size=(1, 2, 4, 4),
        fullname="interpolate_bicubic_tuple_2d_align_corners",
        pickle=False,
    ),
    dict(
        constructor=wrap_functional(
            F.interpolate,
            size=None,
            scale_factor=(2.0, 1.0),
            mode="bicubic",
            align_corners=True,
        ),
        cpp_options_args="""F::InterpolateFuncOptions()
                            .size(c10::nullopt)
                            .scale_factor(std::vector<double>({2., 1.}))
                            .mode(torch::kBicubic)
                            .align_corners(true)""",
        input_size=(1, 2, 4, 4),
        fullname="interpolate_bicubic_scale_tuple_skewed_2d_align_corners",
        pickle=False,
    ),
    dict(
        constructor=wrap_functional(
            F.interpolate, size=12, scale_factor=None, mode="nearest"
        ),
        cpp_options_args="""F::InterpolateFuncOptions()
                            .size(std::vector<int64_t>({12, 12, 12}))
                            .scale_factor(c10::nullopt)
                            .mode(torch::kNearest)""",
        input_size=(1, 2, 4, 4, 4),
        fullname="interpolate_nearest_3d",
        pickle=False,
    ),
    dict(
        constructor=wrap_functional(
            F.interpolate, size=12, scale_factor=None, mode="nearest"
        ),
        cpp_options_args="""F::InterpolateFuncOptions()
                            .size(std::vector<int64_t>({12, 12, 12}))
                            .scale_factor(c10::nullopt)
                            .mode(torch::kNearest)""",
        input_size=(0, 2, 4, 4, 4),
        fullname="interpolate_nearest_3d_zero_dim",
        pickle=False,
    ),
    dict(
        constructor=wrap_functional(
            F.interpolate, size=(12, 16, 16), scale_factor=None, mode="nearest"
        ),
        cpp_options_args="""F::InterpolateFuncOptions()
                            .size(std::vector<int64_t>({12, 16, 16}))
                            .scale_factor(c10::nullopt)
                            .mode(torch::kNearest)""",
        input_size=(1, 2, 3, 4, 4),
        fullname="interpolate_nearest_tuple_3d",
        pickle=False,
    ),
    dict(
        constructor=wrap_functional(
            F.interpolate, size=None, scale_factor=4.0, mode="nearest"
        ),
        cpp_options_args="""F::InterpolateFuncOptions()
                            .size(c10::nullopt)
                            .scale_factor(std::vector<double>({4., 4., 4.}))
                            .mode(torch::kNearest)""",
        input_size=(1, 2, 4, 4, 4),
        fullname="interpolate_nearest_scale_3d",
        pickle=False,
    ),
    dict(
        constructor=wrap_functional(
            F.interpolate,
            size=12,
            scale_factor=None,
            mode="trilinear",
            align_corners=False,
        ),
        cpp_options_args="""F::InterpolateFuncOptions()
                            .size(std::vector<int64_t>({12, 12, 12}))
                            .scale_factor(c10::nullopt)
                            .mode(torch::kTrilinear)
                            .align_corners(false)""",
        input_size=(1, 2, 4, 4, 4),
        fullname="interpolate_trilinear_3d",
        pickle=False,
    ),
    dict(
        constructor=wrap_functional(
            F.interpolate,
            size=12,
            scale_factor=None,
            mode="trilinear",
            align_corners=False,
        ),
        cpp_options_args="""F::InterpolateFuncOptions()
                            .size(std::vector<int64_t>({12, 12, 12}))
                            .scale_factor(c10::nullopt)
                            .mode(torch::kTrilinear)
                            .align_corners(false)""",
        input_size=(0, 2, 4, 4, 4),
        fullname="interpolate_trilinear_3d_zero_dim",
        pickle=False,
    ),
    dict(
        constructor=wrap_functional(
            F.interpolate,
            size=(4, 6, 6),
            scale_factor=None,
            mode="trilinear",
            align_corners=False,
        ),
        cpp_options_args="""F::InterpolateFuncOptions()
                            .size(std::vector<int64_t>({4, 6, 6}))
                            .scale_factor(c10::nullopt)
                            .mode(torch::kTrilinear)
                            .align_corners(false)""",
        input_size=(1, 2, 2, 3, 3),
        fullname="interpolate_trilinear_tuple_3d",
        pickle=False,
    ),
    dict(
        constructor=wrap_functional(
            F.interpolate,
            size=None,
            scale_factor=3.0,
            mode="trilinear",
            align_corners=False,
        ),
        cpp_options_args="""F::InterpolateFuncOptions()
                            .size(c10::nullopt)
                            .scale_factor(std::vector<double>({3., 3., 3.}))
                            .mode(torch::kTrilinear)
                            .align_corners(false)""",
        input_size=(1, 2, 3, 4, 4),
        fullname="interpolate_trilinear_scale_3d",
        # See https://github.com/pytorch/pytorch/issues/5006
        precision=3e-4,
        pickle=False,
    ),
    dict(
        constructor=wrap_functional(
            F.interpolate,
            size=(4, 6, 6),
            scale_factor=None,
            mode="trilinear",
            align_corners=True,
        ),
        cpp_options_args="""F::InterpolateFuncOptions()
                            .size(std::vector<int64_t>({4, 6, 6}))
                            .scale_factor(c10::nullopt)
                            .mode(torch::kTrilinear)
                            .align_corners(true)""",
        input_size=(1, 2, 2, 3, 3),
        fullname="interpolate_trilinear_tuple_3d_align_corners",
        pickle=False,
    ),
    dict(
        constructor=wrap_functional(
            F.interpolate,
            size=None,
            scale_factor=3.0,
            mode="trilinear",
            align_corners=True,
        ),
        cpp_options_args="""F::InterpolateFuncOptions()
                            .size(c10::nullopt)
                            .scale_factor(std::vector<double>({3., 3., 3.}))
                            .mode(torch::kTrilinear)
                            .align_corners(true)""",
        input_size=(1, 2, 3, 4, 4),
        fullname="interpolate_trilinear_scale_3d_align_corners",
        # See https://github.com/pytorch/pytorch/issues/5006
        precision=3e-4,
        pickle=False,
    ),
    dict(
        module_name="AdaptiveMaxPool1d",
        constructor_args=(3,),
        cpp_constructor_args="torch::nn::AdaptiveMaxPool1dOptions(3)",
        input_fn=lambda: _rand_tensor_non_equal(1, 3, 5),
    ),
    dict(
        module_name="AdaptiveMaxPool2d",
        constructor_args=(3,),
        cpp_constructor_args="torch::nn::AdaptiveMaxPool2dOptions(3)",
        input_fn=lambda: _rand_tensor_non_equal(1, 3, 5, 6),
        desc="single",
    ),
    dict(
        module_name="AdaptiveMaxPool2d",
        constructor_args=((3, 4),),
        cpp_constructor_args="torch::nn::AdaptiveMaxPool2dOptions({3, 4})",
        input_fn=lambda: _rand_tensor_non_equal(1, 3, 5, 6),
        desc="tuple",
    ),
    dict(
        module_name="AdaptiveMaxPool2d",
        constructor_args=((3, None),),
        cpp_constructor_args="torch::nn::AdaptiveMaxPool2dOptions({3, c10::nullopt})",
        input_fn=lambda: _rand_tensor_non_equal(1, 3, 5, 6),
        desc="tuple_none",
    ),
    dict(
        module_name="AdaptiveMaxPool3d",
        constructor_args=(3,),
        cpp_constructor_args="torch::nn::AdaptiveMaxPool3dOptions(3)",
        input_fn=lambda: _rand_tensor_non_equal(2, 3, 5, 6, 7),
        desc="single",
    ),
    dict(
        module_name="AdaptiveMaxPool3d",
        constructor_args=((3, 4, 5),),
        cpp_constructor_args="torch::nn::AdaptiveMaxPool3dOptions({3, 4, 5})",
        input_fn=lambda: _rand_tensor_non_equal(2, 3, 5, 6, 7),
        desc="tuple",
    ),
    dict(
        module_name="AdaptiveMaxPool3d",
        constructor_args=((3, None, 5),),
        cpp_constructor_args="torch::nn::AdaptiveMaxPool3dOptions({3, c10::nullopt, 5})",
        input_fn=lambda: _rand_tensor_non_equal(2, 3, 5, 6, 7),
        desc="tuple_none",
    ),
    dict(
        module_name="AdaptiveMaxPool3d",
        constructor_args=(3,),
        cpp_constructor_args="torch::nn::AdaptiveMaxPool3dOptions(3)",
        input_fn=lambda: _rand_tensor_non_equal(2, 3, 12, 9, 3),
        desc="single_nonatomic",
    ),
    dict(
        module_name="AdaptiveMaxPool3d",
        constructor_args=((3, 4, 5),),
        cpp_constructor_args="torch::nn::AdaptiveMaxPool3dOptions({3, 4, 5})",
        input_fn=lambda: _rand_tensor_non_equal(2, 3, 6, 4, 10),
        desc="tuple_nonatomic",
    ),
    dict(
        module_name="AdaptiveAvgPool1d",
        constructor_args=(3,),
        cpp_constructor_args="torch::nn::AdaptiveAvgPool1dOptions(3)",
        input_fn=lambda: torch.rand(1, 3, 5),
    ),
    dict(
        module_name="AdaptiveAvgPool1d",
        constructor_args=(1,),
        cpp_constructor_args="torch::nn::AdaptiveAvgPool1dOptions(1)",
        input_fn=lambda: torch.rand(1, 3, 5),
        desc="one_output",
    ),
    dict(
        module_name="AdaptiveAvgPool2d",
        constructor_args=(3,),
        cpp_constructor_args="torch::nn::AdaptiveAvgPool2dOptions(3)",
        input_fn=lambda: torch.rand(1, 3, 5, 6),
        desc="single",
        check_gradgrad=False,
    ),
    dict(
        module_name="AdaptiveAvgPool2d",
        constructor_args=(1,),
        cpp_constructor_args="torch::nn::AdaptiveAvgPool2dOptions(1)",
        input_fn=lambda: torch.rand(1, 3, 5, 6),
        desc="single_1x1output",
        check_gradgrad=False,
    ),
    dict(
        module_name="AdaptiveAvgPool2d",
        constructor_args=((3, 4),),
        cpp_constructor_args="torch::nn::AdaptiveAvgPool2dOptions({3, 4})",
        input_fn=lambda: torch.rand(1, 3, 5, 6),
        desc="tuple",
        check_gradgrad=False,
    ),
    dict(
        module_name="AdaptiveAvgPool2d",
        constructor_args=((3, None),),
        cpp_constructor_args="torch::nn::AdaptiveAvgPool2dOptions({3, c10::nullopt})",
        input_fn=lambda: torch.rand(1, 3, 5, 6),
        desc="tuple_none",
        check_gradgrad=False,
    ),
    dict(
        module_name="AdaptiveAvgPool3d",
        constructor_args=(3,),
        cpp_constructor_args="torch::nn::AdaptiveAvgPool3dOptions(3)",
        input_fn=lambda: torch.rand(2, 3, 5, 2, 7),
        desc="single",
    ),
    dict(
        module_name="AdaptiveAvgPool3d",
        constructor_args=((3, 4, 5),),
        cpp_constructor_args="torch::nn::AdaptiveAvgPool3dOptions({3, 4, 5})",
        input_fn=lambda: torch.rand(2, 3, 5, 3, 7),
        desc="tuple",
    ),
    dict(
        module_name="AdaptiveAvgPool3d",
        constructor_args=((None, 4, 5),),
        cpp_constructor_args="torch::nn::AdaptiveAvgPool3dOptions({c10::nullopt, 4, 5})",
        input_fn=lambda: torch.rand(2, 3, 5, 3, 7),
        desc="tuple_none",
    ),
    dict(module_name="SELU", input_size=(3, 2, 5), check_inplace=True),
3206
    dict(module_name="SELU", input_size=(), check_inplace=True, desc="scalar"),
3207
    dict(
3208
        module_name="CELU",
3209
        input_size=(3, 2, 5),
3210
        constructor_args=(2.0,),
3211
        cpp_constructor_args="torch::nn::CELUOptions().alpha(2.)",
3212
        check_inplace=True,
3213
        reference_fn=lambda x, *_: torch.where(x >= 0, x, 2.0 * ((0.5 * x).exp() - 1)),
3214
    ),
3215
    dict(
3216
        module_name="CELU",
3217
        input_size=(),
3218
        constructor_args=(2.0,),
3219
        cpp_constructor_args="torch::nn::CELUOptions().alpha(2.)",
3220
        check_inplace=True,
3221
        reference_fn=lambda x, *_: torch.where(x >= 0, x, 2.0 * ((0.5 * x).exp() - 1)),
3222
        desc="scalar",
3223
    ),
3224
    dict(
3225
        module_name="GLU",
3226
        input_size=(5, 6),
3227
    ),
3228
    dict(
3229
        module_name="GLU",
3230
        constructor_args=(1,),
3231
        cpp_constructor_args="torch::nn::GLUOptions(1)",
3232
        input_size=(5, 6, 7),
3233
        desc="dim",
3234
    ),
3235
    dict(
3236
        module_name="GELU",
3237
        input_size=(),
3238
        desc="scalar",
3239
        reference_fn=lambda x, *_: x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0))),
3240
    ),
3241
    dict(
3242
        module_name="GELU",
3243
        input_size=(3, 2, 5),
3244
        reference_fn=lambda x, *_: x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0))),
3245
    ),
3246
    dict(
3247
        constructor=wrap_functional(F.softmax, dim=-1),
3248
        cpp_options_args="F::SoftmaxFuncOptions(-1)",
3249
        input_size=(2, 128),  # trigger the last-dim algo in CUDA
3250
        fullname="softmax_lastdim",
3251
        pickle=False,
3252
    ),
3253
    dict(
3254
        constructor=wrap_functional(F.softmax, dim=1, dtype=torch.float64),
3255
        cpp_options_args="F::SoftmaxFuncOptions(1).dtype(torch::kFloat64)",
3256
        input_size=(2, 128),
3257
        fullname="softmax_lastdim_dtype",
3258
        pickle=False,
3259
        test_cuda=False,
3260
    ),
3261
    dict(
3262
        constructor=wrap_functional(F.softmax, dim=1),
3263
        cpp_options_args="F::SoftmaxFuncOptions(1)",
3264
        input_size=(2, 128, 2, 2),  # trigger special case of spatial CUDA algo
3265
        fullname="softmax_spatial_special",
3266
        pickle=False,
3267
        test_cuda=(not TEST_WITH_ROCM),
3268
    ),
3269
    dict(
3270
        constructor=wrap_functional(F.softmax, dim=1),
3271
        cpp_options_args="F::SoftmaxFuncOptions(1)",
3272
        input_size=(2, 2, 4, 4),  # regular spatial algorithm
3273
        fullname="softmax_spatial",
3274
        pickle=False,
3275
    ),
3276
    dict(
3277
        constructor=wrap_functional(F.softmax, dim=1, dtype=torch.float64),
3278
        cpp_options_args="F::SoftmaxFuncOptions(1).dtype(torch::kFloat64)",
3279
        input_size=(2, 2, 4, 4),  # regular spatial algorithm
3280
        fullname="softmax_spatial_dtype",
3281
        pickle=False,
3282
        test_cuda=False,
3283
    ),
3284
    dict(
3285
        constructor=wrap_functional(F.softmax, dim=0),
3286
        cpp_options_args="F::SoftmaxFuncOptions(0)",
3287
        input_size=(2, 3, 4, 5),
3288
        fullname="softmax_functional_dim0",
3289
        test_cuda=False,
3290
        pickle=False,
3291
    ),
3292
    dict(
3293
        constructor=wrap_functional(F.softmax, dim=3),
3294
        cpp_options_args="F::SoftmaxFuncOptions(3)",
3295
        input_size=(2, 3, 4, 5),
3296
        fullname="softmax_functional_dim3",
3297
        test_cuda=False,
3298
        pickle=False,
3299
    ),
3300
    dict(
3301
        constructor=wrap_functional(F.softmax, dim=-1),
3302
        cpp_options_args="F::SoftmaxFuncOptions(-1)",
3303
        input_size=(),
3304
        fullname="softmax_functional_scalar",
3305
        test_cuda=False,
3306
        pickle=False,
3307
    ),
3308
    dict(
3309
        constructor=wrap_functional(F.log_softmax, dim=-1),
3310
        cpp_options_args="F::LogSoftmaxFuncOptions(-1)",
3311
        input_size=(2, 128),  # trigger the last-dim algo in CUDA
3312
        fullname="log_softmax_lastdim",
3313
        pickle=False,
3314
    ),
3315
    dict(
3316
        constructor=wrap_functional(F.log_softmax, dim=1),
3317
        cpp_options_args="F::LogSoftmaxFuncOptions(1)",
3318
        input_size=(2, 128, 2, 2),  # trigger special case of spatial CUDA algo
3319
        fullname="log_softmax_spatial_special",
3320
        pickle=False,
3321
        test_cuda=(not TEST_WITH_ROCM),
3322
    ),
3323
    dict(
3324
        constructor=wrap_functional(F.log_softmax, dim=1),
3325
        cpp_options_args="F::LogSoftmaxFuncOptions(1)",
3326
        input_size=(2, 2, 4, 4),  # regular spatial algorithm
3327
        fullname="log_softmax_spatial",
3328
        pickle=False,
3329
    ),
3330
    dict(
3331
        constructor=wrap_functional(F.log_softmax, dim=0),
3332
        cpp_options_args="F::LogSoftmaxFuncOptions(0)",
3333
        input_size=(2, 3, 4, 5),
3334
        fullname="log_softmax_dim0",
3335
        pickle=False,
3336
    ),
3337
    dict(
3338
        constructor=wrap_functional(F.log_softmax, dim=3),
3339
        cpp_options_args="F::LogSoftmaxFuncOptions(3)",
3340
        input_size=(2, 3, 4, 5),
3341
        fullname="log_softmax_dim3",
3342
        pickle=False,
3343
    ),
3344
    dict(
3345
        constructor=wrap_functional(F.log_softmax, dim=0),
3346
        cpp_options_args="F::LogSoftmaxFuncOptions(0)",
3347
        input_size=(),
3348
        fullname="log_softmax_scalar",
3349
        pickle=False,
3350
    ),
3351
    dict(
        fullname="Unfold",
        constructor=lambda: nn.Unfold((2, 2), (1, 1), (0, 0), (1, 1)),
        cpp_constructor_args="torch::nn::UnfoldOptions({2, 2}).dilation({1, 1}).padding({0, 0}).stride({1, 1})",
        input_size=(2, 4, 3, 3),
        check_gradgrad=False,
        test_cuda=True,
    ),
    dict(
        fullname="Fold",
        constructor=lambda: nn.Fold((3, 3), (2, 2), (1, 1), (0, 0), (1, 1)),
        cpp_constructor_args="torch::nn::FoldOptions({3, 3}, {2, 2}).dilation({1, 1}).padding({0, 0}).stride({1, 1})",
        input_size=(2, 16, 4),
        check_gradgrad=False,
        test_cuda=True,
    ),
    dict(
        fullname="Unfold_int_input",
        constructor=lambda: nn.Unfold(2, 1, 0, 1),
        cpp_constructor_args="torch::nn::UnfoldOptions(2).dilation(1).padding(0).stride(1)",
        input_size=(2, 4, 3, 3),
        check_gradgrad=False,
        test_cuda=True,
    ),
    dict(
        fullname="Fold_int_input",
        constructor=lambda: nn.Fold(3, 2, 1, 0, 1),
        cpp_constructor_args="torch::nn::FoldOptions(3, 2).dilation(1).padding(0).stride(1)",
        input_size=(2, 16, 4),
        check_gradgrad=False,
        test_cuda=True,
    ),
    dict(
        module_name="Threshold",
        constructor_args=(2.0, 1.0),
        cpp_constructor_args="torch::nn::ThresholdOptions(2., 1.)",
        input_size=(),
        check_inplace=True,
        desc="threshold_value_scalar",
    ),
    dict(module_name="ReLU", input_size=(), check_inplace=True, desc="scalar"),
    dict(module_name="ReLU6", input_size=(), check_inplace=True, desc="scalar"),
    dict(
        module_name="RReLU",
        constructor_args=(0.1, 0.9),
        cpp_constructor_args="torch::nn::RReLUOptions().lower(0.1).upper(0.9)",
        input_size=(),
        desc="with_up_down_scalar",
        test_cuda=False,
    ),
    dict(
        module_name="Hardtanh",
        input_size=(),
        reference_fn=lambda i, *_: i.clamp(-1, 1),
        desc="scalar",
    ),
    dict(
        module_name="Sigmoid",
        input_size=(),
        desc="scalar",
    ),
    dict(
        module_name="Tanh",
        input_size=(),
        desc="scalar",
    ),
    dict(
        module_name="Softmax",
        constructor_args=(0,),
        cpp_constructor_args="torch::nn::SoftmaxOptions(0)",
        input_size=(),
        reference_fn=lambda i, *_: torch.exp(i).div(torch.exp(i).sum(0, True)),
        desc="scalar",
    ),
    dict(
        module_name="LogSoftmax",
        constructor_args=(0,),
        cpp_constructor_args="torch::nn::LogSoftmaxOptions(0)",
        input_size=(),
        reference_fn=lambda i, *_: torch.exp(i).div_(torch.exp(i).sum(0, False)).log_(),
        desc="multiparam_scalar",
    ),
    dict(
        module_name="ELU",
        constructor_args=(2.0,),
        cpp_constructor_args="torch::nn::ELUOptions().alpha(2.)",
        input_size=(),
        desc="scalar",
    ),
    dict(
        module_name="Hardshrink",
        constructor_args=(2.0,),
        cpp_constructor_args="torch::nn::HardshrinkOptions(2.)",
        input_size=(),
        desc="scalar",
    ),
    dict(
        module_name="LeakyReLU",
        constructor_args=(0.5,),
        cpp_constructor_args="torch::nn::LeakyReLUOptions().negative_slope(0.5)",
        input_size=(),
        check_inplace=True,
        desc="with_negval_scalar",
    ),
    dict(
        module_name="LogSigmoid",
        input_size=(),
        reference_fn=lambda i, *_: i.sigmoid().log(),
        desc="scalar",
    ),
    dict(
        module_name="Softplus",
        constructor_args=(2, -100),
        cpp_constructor_args="torch::nn::SoftplusOptions().beta(2).threshold(-100)",
        input_size=(),
        reference_fn=(
            lambda i, *_: ((i * 2) > -100).type_as(i) * i
            + ((i * 2) <= -100).type_as(i) * 1.0 / 2.0 * torch.log(1 + torch.exp(2 * i))
        ),
        desc="beta_threshold_scalar",
    ),
    dict(
        module_name="Softshrink",
        constructor_args=(1,),
        cpp_constructor_args="torch::nn::SoftshrinkOptions(1)",
        input_size=(),
        desc="lambda_scalar",
    ),
    dict(
        module_name="PReLU",
        input_size=(),
        reference_fn=lambda i, p, _: torch.clamp(i, min=0)
        + torch.clamp(i, max=0) * p[0][0],
        desc="scalar",
    ),
    dict(
        module_name="Softsign",
        input_size=(),
        reference_fn=lambda i, *_: i.div(1 + torch.abs(i)),
        desc="scalar",
    ),
    dict(
        module_name="Softmin",
        constructor_args=(0,),
        cpp_constructor_args="torch::nn::SoftminOptions(0)",
        input_size=(),
        desc="scalar",
    ),
    dict(
        module_name="Tanhshrink",
        input_size=(),
        desc="scalar",
    ),
    dict(
        fullname="Padding12_1dcircular",
        constructor=wrap_functional(F.pad, pad=(1, 2), mode="circular"),
        cpp_options_args="F::PadFuncOptions({1, 2}).mode(torch::kCircular)",
        input_fn=lambda: torch.arange(6, out=torch.DoubleTensor()).reshape([1, 2, 3]),
        reference_fn=lambda i, *_: padding1d_circular(i, (1, 2)),
        skip_double=TEST_WITH_ROCM,
        pickle=False,
    ),
    dict(
        fullname="Padding31_1dcircular",
        constructor=wrap_functional(F.pad, pad=(3, 1), mode="circular"),
        cpp_options_args="F::PadFuncOptions({3, 1}).mode(torch::kCircular)",
        input_fn=lambda: torch.arange(6, out=torch.DoubleTensor()).reshape([1, 2, 3]),
        reference_fn=lambda i, *_: padding1d_circular(i, (3, 1)),
        skip_double=TEST_WITH_ROCM,
        pickle=False,
    ),
    dict(
        fullname="Padding33_1dcircular",
        constructor=wrap_functional(F.pad, pad=(3, 3), mode="circular"),
        cpp_options_args="F::PadFuncOptions({3, 3}).mode(torch::kCircular)",
        input_fn=lambda: torch.arange(6, out=torch.DoubleTensor()).reshape([1, 2, 3]),
        reference_fn=lambda i, *_: padding1d_circular(i, (3, 3)),
        skip_double=TEST_WITH_ROCM,
        pickle=False,
    ),
    dict(
        fullname="Padding1221_2dcircular",
        constructor=wrap_functional(F.pad, pad=(1, 2, 2, 1), mode="circular"),
        cpp_options_args="F::PadFuncOptions({1, 2, 2, 1}).mode(torch::kCircular)",
        input_fn=lambda: torch.arange(6, out=torch.DoubleTensor()).reshape(
            [1, 1, 2, 3]
        ),
        reference_fn=lambda i, *_: padding2d_circular(i, (1, 2, 2, 1)),
        skip_double=TEST_WITH_ROCM,
        pickle=False,
    ),
    dict(
        fullname="Padding2322_2dcircular",
        constructor=wrap_functional(F.pad, pad=(2, 3, 2, 2), mode="circular"),
        cpp_options_args="F::PadFuncOptions({2, 3, 2, 2}).mode(torch::kCircular)",
        input_fn=lambda: torch.arange(6, out=torch.DoubleTensor()).reshape(
            [1, 1, 2, 3]
        ),
        reference_fn=lambda i, *_: padding2d_circular(i, (2, 3, 2, 2)),
        skip_double=TEST_WITH_ROCM,
        pickle=False,
    ),
    dict(
        fullname="Padding3331_2dcircular",
        constructor=wrap_functional(F.pad, pad=(3, 3, 3, 1), mode="circular"),
        cpp_options_args="F::PadFuncOptions({3, 3, 3, 1}).mode(torch::kCircular)",
        input_fn=lambda: torch.arange(9, out=torch.DoubleTensor()).reshape(
            [1, 1, 3, 3]
        ),
        reference_fn=lambda i, *_: padding2d_circular(i, (3, 3, 3, 1)),
        skip_double=TEST_WITH_ROCM,
        pickle=False,
    ),
    dict(
        fullname="Padding122112_3dcircular",
        constructor=wrap_functional(F.pad, pad=(1, 2, 2, 1, 1, 2), mode="circular"),
        cpp_options_args="F::PadFuncOptions({1, 2, 2, 1, 1, 2}).mode(torch::kCircular)",
        input_fn=lambda: torch.arange(12, out=torch.DoubleTensor()).reshape(
            [1, 1, 2, 2, 3]
        ),
        reference_fn=lambda i, *_: padding3d_circular(i, (1, 2, 2, 1, 1, 2)),
        skip_double=TEST_WITH_ROCM,
        pickle=False,
    ),
    dict(
        fullname="Padding322112_3dcircular",
        constructor=wrap_functional(F.pad, pad=(3, 2, 2, 1, 1, 2), mode="circular"),
        cpp_options_args="F::PadFuncOptions({3, 2, 2, 1, 1, 2}).mode(torch::kCircular)",
        input_fn=lambda: torch.arange(12, out=torch.DoubleTensor()).reshape(
            [1, 1, 2, 2, 3]
        ),
        reference_fn=lambda i, *_: padding3d_circular(i, (3, 2, 2, 1, 1, 2)),
        skip_double=TEST_WITH_ROCM,
        pickle=False,
    ),
    dict(
        fullname="Padding332122_3dcircular",
        constructor=wrap_functional(F.pad, pad=(3, 3, 2, 1, 2, 2), mode="circular"),
        cpp_options_args="F::PadFuncOptions({3, 3, 2, 1, 2, 2}).mode(torch::kCircular)",
        input_fn=lambda: torch.arange(12, out=torch.DoubleTensor()).reshape(
            [1, 1, 2, 2, 3]
        ),
        reference_fn=lambda i, *_: padding3d_circular(i, (3, 3, 2, 1, 2, 2)),
        skip_double=TEST_WITH_ROCM,
        pickle=False,
    ),
]

# add conv padding mode tests:
for padding_mode, cpp_padding_mode in zip(
    ["reflect", "circular", "replicate", "zeros"],
    ["torch::kReflect", "torch::kCircular", "torch::kReplicate", "torch::kZeros"],
):
    # conv signature:
    #     in_channels, out_channels, kernel_size, stride=1,
    #     padding=0, dilation=1, groups=1,
    #     bias=True, padding_mode='zeros'
    for d in (1, 2, 3):
        if d == 3 and padding_mode == "reflect":
            # FIXME: remove after implementing reflection pad 3d
            #        https://github.com/pytorch/pytorch/issues/27655
            continue
        new_module_tests.append(
            dict(
                module_name="Conv{}d".format(d),
                constructor_args=(3, 4, 3, 2, 2, 1, 1, True, padding_mode),
                cpp_constructor_args="""torch::nn::Conv{}dOptions(3, 4, 3)
                                        .stride(2)
                                        .padding(2)
                                        .dilation(1)
                                        .groups(1)
                                        .bias(true)
                                        .padding_mode({})""".format(
                    d, cpp_padding_mode
                ),
                input_size=(2, 3) + (3,) * d,
                output_size=(2, 4) + (3,) * d,
                cudnn=True,
                desc="{}_stride2_pad2".format(padding_mode),
            ),
        )
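
# Illustrative sketch only (not appended to new_module_tests): for d == 2 and
# padding_mode == "circular", the loop above builds -- up to the C++ options
# string -- a configuration like the one returned below, i.e. Conv2d with
# stride 2 and padding 2 under that padding mode. The helper name is
# hypothetical and is never called by the test harness.
def _example_conv_padding_mode_test():
    return dict(
        module_name="Conv2d",
        constructor_args=(3, 4, 3, 2, 2, 1, 1, True, "circular"),
        input_size=(2, 3, 3, 3),
        output_size=(2, 4, 3, 3),
        cudnn=True,
        desc="circular_stride2_pad2",
    )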


def kldivloss_reference(input, target, reduction="mean"):
    safe_target = target * (target > 0).type_as(target)
    safe_target_log = (safe_target + (target <= 0).type_as(target)).log()
    result = safe_target * (safe_target_log - input)
    if reduction == "mean":
        return result.mean()
    elif reduction == "sum":
        return result.sum()
    elif reduction == "batchmean" and result.dim() != 0:
        return result.sum() / result.size(0)
    return result
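

# Hypothetical usage sketch (not executed by the test harness): under "mean"
# reduction the reference above is expected to agree with F.kl_div, which also
# averages the pointwise terms target * (log(target) - input) over all elements.
# The helper name and shapes are illustrative only.
def _kldivloss_reference_example():
    log_input = torch.rand(3, 5).log()
    target = torch.rand(3, 5)
    return torch.allclose(
        kldivloss_reference(log_input, target),
        F.kl_div(log_input, target, reduction="mean"),
    )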


def nlllossNd_reference(
    input, target, weight=None, ignore_index=-100, reduction="mean"
):
    assert input.dim() >= 3
    N = input.size(0)
    C = input.size(1)
    out_size = (N,) + input.size()[2:]
    output = torch.zeros(out_size).type_as(input)

    if weight is None:
        weight = torch.ones(C).type_as(input)
    total_weight = 0
    for tup in product(*[range(size) for size in out_size]):
        t_nx = target[tup]
        norm = 0.0 if ignore_index == t_nx else weight[t_nx].item()
        input_index = list(tup)
        input_index.insert(1, t_nx)
        output[tup] = -input[tuple(input_index)] * norm
        total_weight += norm

    if reduction == "mean":
        return output.sum() / total_weight
    elif reduction == "sum":
        return output.sum()
    return output


def nllloss_reference(input, target, weight=None, ignore_index=-100, reduction="mean"):
    def nll_loss_helper(input, target, weight, ignore_index):
        if target == ignore_index:
            return (0, 0)
        norm = 1 if weight is None else weight[target]
        result = -input[target] * norm
        return (result, norm)

    losses_and_weights = [
        nll_loss_helper(i, t, weight, ignore_index) for i, t in zip(input, target)
    ]
    losses, weights = zip(*losses_and_weights)
    losses_tensor = input.new_tensor(losses)
    if reduction == "mean":
        return sum(losses_tensor) / sum(weights)
    elif reduction == "sum":
        return sum(losses_tensor)
    else:
        return losses_tensor
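

# Hypothetical sanity sketch (not run by the harness): with a per-class weight,
# the reference above should match F.nll_loss, whose "mean" reduction also
# divides by the sum of the selected class weights rather than by the batch
# size. The helper name and shapes are illustrative only.
def _nllloss_reference_example():
    log_probs = torch.rand(6, 4).log()
    target = torch.tensor([0, 1, 2, 3, 1, 2])
    weight = torch.rand(4)
    return torch.allclose(
        nllloss_reference(log_probs, target, weight=weight),
        F.nll_loss(log_probs, target, weight=weight, reduction="mean"),
    )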


def smoothl1loss_reference(input, target, reduction="mean"):
    abs_diff = (input - target).abs()
    ge_one_mask = (abs_diff >= 1).type_as(abs_diff)
    lt_one_mask = (abs_diff < 1).type_as(abs_diff)
    output = ge_one_mask * (abs_diff - 0.5) + lt_one_mask * 0.5 * (abs_diff**2)
    if reduction == "mean":
        return output.mean()
    elif reduction == "sum":
        return output.sum()
    return output


def _multilabelmarginloss_reference(input, target):
    targets = []
    for target_index in target:
        if target_index < 0:
            break
        targets.append(target_index)

    sum = 0
    for target_index in targets:
        for i in range(0, len(input)):
            if i not in targets:
                sum += max(0, 1 - input[target_index] + input[i])

    return sum


def multilabelmarginloss_reference(input, target, reduction="mean"):
    # make everything 2-dimensional
    input_dim = input.dim()
    if input.dim() < 2:
        assert target.dim() < 2
        input = (
            input.unsqueeze(0) if input.dim() == 1 else input.unsqueeze(0).unsqueeze(0)
        )
        target = (
            target.unsqueeze(0)
            if target.dim() == 1
            else target.unsqueeze(0).unsqueeze(0)
        )

    n = input.size(0)
    dim = input.size(1)
    output = input.new(n).zero_()
    for i in range(0, n):
        output[i] = _multilabelmarginloss_reference(input[i], target[i])

    if reduction == "mean":
        return output.mean() / dim
    elif reduction == "sum":
        return output.sum() / dim
    elif input_dim < 2:
        # we know we have (1, C) X (1, C) -> (1,), so squeeze will get us
        # back to correct dimensionality
        return output.squeeze() / dim
    else:
        return output / dim
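

# Hypothetical usage sketch (not executed by the harness): targets are class
# indices padded with -1, and the per-sample hinge sum is normalised by the
# number of classes, which is also how F.multilabel_margin_loss behaves. The
# helper name and shapes are illustrative only.
def _multilabelmarginloss_reference_example():
    input = torch.randn(2, 4)
    target = torch.tensor([[1, 3, -1, -1], [0, -1, -1, -1]])
    return torch.allclose(
        multilabelmarginloss_reference(input, target),
        F.multilabel_margin_loss(input, target, reduction="mean"),
    )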


def hingeembeddingloss_reference(input, target, margin=1.0, reduction="mean"):
    margin_clamp = (margin - input).clamp(min=0).type_as(input)
    output = torch.where(target == 1, input, margin_clamp)

    if reduction == "mean":
        return output.mean()
    elif reduction == "sum":
        return output.sum()
    return output


def softmarginloss_reference(input, target, reduction="mean"):
    output = (1 + (-input * target).exp()).log()

    if reduction == "mean":
        return output.mean()
    elif reduction == "sum":
        return output.sum()
    return output


def _multimarginloss_reference(input, target_idx, p, margin, weight):
    if weight is None:
        weight = input.new(len(input)).fill_(1)

    output = 0
    for i in range(0, len(input)):
        if i != target_idx:
            output += max(
                0, weight[target_idx] * (margin - input[target_idx] + input[i]) ** p
            )
    return output


def multimarginloss_reference(
    input, target, p=1, margin=1, weight=None, reduction="mean"
):
    if input.dim() < 2:
        input = (
            input.unsqueeze(0) if input.dim() == 1 else input.unsqueeze(0).unsqueeze(0)
        )

    target_dim = target.dim()
    if target.dim() == 0:
        target = target.unsqueeze(0)

    n = input.size(0)
    dim = input.size(1)
    output = input.new(n)
    for x in range(0, n):
        output[x] = _multimarginloss_reference(input[x], target[x], p, margin, weight)

    if reduction == "mean":
        return output.mean() / dim
    elif reduction == "sum":
        return output.sum() / dim
    elif target_dim == 0:
        return output.squeeze(0) / dim
    return output / dim


def cosineembeddingloss_reference(input1, input2, target, margin=0, reduction="mean"):
    def _cos(a, b):
        cos = a.new(a.size(0))
        for i in range(0, a.size(0)):
            cos[i] = (a[i] * b[i]).sum() / (
                (((a[i] * a[i]).sum() + 1e-12) * ((b[i] * b[i]).sum() + 1e-12)) ** 0.5
            )
        return cos

    output = torch.where(
        target == 1,
        1 - _cos(input1, input2),
        (_cos(input1, input2) - margin).clamp(min=0),
    )

    if reduction == "mean":
        return output.mean()
    elif reduction == "sum":
        return output.sum()
    return output
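

# Hypothetical sanity sketch (not executed by the harness): the loop-based
# cosine computation above is expected to line up with F.cosine_embedding_loss
# for +/-1 targets, up to small numerical differences in the epsilon handling.
# The helper name and shapes are illustrative only.
def _cosineembeddingloss_reference_example():
    x1, x2 = torch.randn(4, 8), torch.randn(4, 8)
    y = torch.tensor([1.0, -1.0, 1.0, -1.0])
    return torch.allclose(
        cosineembeddingloss_reference(x1, x2, y, margin=0.1),
        F.cosine_embedding_loss(x1, x2, y, margin=0.1, reduction="mean"),
        atol=1e-6,
    )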


def tripletmarginloss_reference(
    anchor, positive, negative, margin=1.0, p=2, eps=1e-6, swap=False, reduction="mean"
):
    d_p = torch.pairwise_distance(anchor, positive, p, eps)
    d_n = torch.pairwise_distance(anchor, negative, p, eps)
    if swap:
        d_s = torch.pairwise_distance(positive, negative, p, eps)
        d_n = torch.min(d_n, d_s)

    output = torch.clamp(margin + d_p - d_n, min=0.0)
    if reduction == "mean":
        return output.mean()
    elif reduction == "sum":
        return output.sum()
    return output
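

# Hypothetical usage sketch (not run by the harness): with the same margin, p
# and swap arguments the reference above should track F.triplet_margin_loss.
# The helper name and shapes are illustrative only.
def _tripletmarginloss_reference_example():
    anchor, pos, neg = torch.randn(5, 16), torch.randn(5, 16), torch.randn(5, 16)
    return torch.allclose(
        tripletmarginloss_reference(anchor, pos, neg, margin=1.0, p=2, swap=True),
        F.triplet_margin_loss(anchor, pos, neg, margin=1.0, p=2, swap=True),
    )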


def marginrankingloss_reference(input1, input2, target, margin=0, reduction="mean"):
    output = (-target * (input1 - input2) + margin).clamp(min=0)
    if reduction == "mean":
        return output.mean()
    elif reduction == "sum":
        return output.sum()
    return output


# This directly follows Graves et al.'s paper; in contrast to the production
# implementation, it does not use log-space.
def ctcloss_reference(
    log_probs, targets, input_lengths, target_lengths, blank=0, reduction="mean"
):
    input_lengths = torch.as_tensor(input_lengths, dtype=torch.long)
    target_lengths = torch.as_tensor(target_lengths, dtype=torch.long)
    dt = log_probs.dtype
    log_probs = log_probs.double()  # we need the accuracy as we are not in logspace
    targets = targets.long()
    cum_target_lengths = target_lengths.cumsum(0)
    losses = []
    for i in range(log_probs.size(1)):
        input_length = input_lengths[i].item()
        target_length = target_lengths[i].item()
        cum_target_length = cum_target_lengths[i].item()
        targets_prime = targets.new_full((2 * target_length + 1,), blank)
        if targets.dim() == 2:
            targets_prime[1::2] = targets[i, :target_length]
        else:
            targets_prime[1::2] = targets[
                cum_target_length - target_length : cum_target_length
            ]
        probs = log_probs[:input_length, i].exp()
        alpha = log_probs.new_zeros((target_length * 2 + 1,))
        alpha[0] = probs[0, blank]
        alpha[1] = probs[0, targets_prime[1]]
        mask_third = targets_prime[:-2] != targets_prime[2:]
        for t in range(1, input_length):
            alpha_next = alpha.clone()
            alpha_next[1:] += alpha[:-1]
            alpha_next[2:] += torch.where(mask_third, alpha[:-2], alpha.new_zeros(1))
            alpha = probs[t, targets_prime] * alpha_next
        losses.append(-alpha[-2:].sum().log()[None])
    output = torch.cat(losses, 0)
    if reduction == "mean":
        return (
            output / target_lengths.to(dtype=output.dtype, device=output.device)
        ).mean()
    elif reduction == "sum":
        return output.sum()
    output = output.to(dt)
    return output
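

# Hypothetical sanity sketch (not executed by the harness): the probability-space
# forward recursion above should approximate F.ctc_loss, which works in
# log-space, so agreement is only up to floating-point error, hence the loose
# tolerance. The helper name and shapes are illustrative only.
def _ctcloss_reference_example():
    T, N, C = 12, 2, 5
    log_probs = torch.randn(T, N, C).log_softmax(2)
    targets = torch.randint(1, C, (N, 6), dtype=torch.long)
    input_lengths = torch.full((N,), T, dtype=torch.long)
    target_lengths = torch.tensor([6, 4], dtype=torch.long)
    return torch.allclose(
        ctcloss_reference(log_probs, targets, input_lengths, target_lengths).float(),
        F.ctc_loss(log_probs, targets, input_lengths, target_lengths),
        atol=1e-4,
    )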


def padding1d_circular(input, pad):
    r"""input:
      [[[0., 1., 2.],
        [3., 4., 5.]]]
    pad: (1, 2)
    output:
      [[[2., 0., 1., 2., 0., 1.],
        [5., 3., 4., 5., 3., 4.]]]
    """
    return torch.cat([input[:, :, -pad[0] :], input, input[:, :, 0 : pad[1]]], dim=2)


def padding2d_circular(input, pad):
    r"""input:
      [[[[0., 1., 2.],
         [3., 4., 5.]]]]
    pad: (1, 2, 2, 1)
    output:
        [[[[2., 0., 1., 2., 0., 1.],
           [5., 3., 4., 5., 3., 4.],
           [2., 0., 1., 2., 0., 1.],
           [5., 3., 4., 5., 3., 4.],
           [2., 0., 1., 2., 0., 1.]]]]
    """
    input = torch.cat([input[:, :, -pad[2] :], input, input[:, :, 0 : pad[3]]], dim=2)
    return torch.cat(
        [input[:, :, :, -pad[0] :], input, input[:, :, :, 0 : pad[1]]], dim=3
    )


def padding3d_circular(input, pad):
    r"""input:
        [[[[[ 0.,  1.,  2.],
            [ 3.,  4.,  5.]],
           [[ 6.,  7.,  8.],
            [ 9., 10., 11.]]]]]
    pad: (1, 2, 2, 1, 1, 2)
    output: [[[[[ 8.,  6.,  7.,  8.,  6.,  7.],
           [11.,  9., 10., 11.,  9., 10.],
           [ 8.,  6.,  7.,  8.,  6.,  7.],
           [11.,  9., 10., 11.,  9., 10.],
           [ 8.,  6.,  7.,  8.,  6.,  7.]],

          [[ 2.,  0.,  1.,  2.,  0.,  1.],
           [ 5.,  3.,  4.,  5.,  3.,  4.],
           [ 2.,  0.,  1.,  2.,  0.,  1.],
           [ 5.,  3.,  4.,  5.,  3.,  4.],
           [ 2.,  0.,  1.,  2.,  0.,  1.]],

          [[ 8.,  6.,  7.,  8.,  6.,  7.],
           [11.,  9., 10., 11.,  9., 10.],
           [ 8.,  6.,  7.,  8.,  6.,  7.],
           [11.,  9., 10., 11.,  9., 10.],
           [ 8.,  6.,  7.,  8.,  6.,  7.]],

          [[ 2.,  0.,  1.,  2.,  0.,  1.],
           [ 5.,  3.,  4.,  5.,  3.,  4.],
           [ 2.,  0.,  1.,  2.,  0.,  1.],
           [ 5.,  3.,  4.,  5.,  3.,  4.],
           [ 2.,  0.,  1.,  2.,  0.,  1.]],

          [[ 8.,  6.,  7.,  8.,  6.,  7.],
           [11.,  9., 10., 11.,  9., 10.],
           [ 8.,  6.,  7.,  8.,  6.,  7.],
           [11.,  9., 10., 11.,  9., 10.],
           [ 8.,  6.,  7.,  8.,  6.,  7.]]]]]
    """
    input = torch.cat([input[:, :, -pad[4] :], input, input[:, :, 0 : pad[5]]], dim=2)
    input = torch.cat(
        [input[:, :, :, -pad[2] :], input, input[:, :, :, 0 : pad[3]]], dim=3
    )
    return torch.cat(
        [input[:, :, :, :, -pad[0] :], input, input[:, :, :, :, 0 : pad[1]]], dim=4
    )
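

# Hypothetical usage sketch (not executed by the harness): the helpers above are
# meant to mirror F.pad with mode="circular"; F.pad orders the pad tuple from
# the last dimension backwards, which is the same convention used here. The
# helper name and shapes are illustrative only.
def _padding_circular_example():
    x = torch.arange(6, dtype=torch.double).reshape(1, 1, 2, 3)
    pad = (1, 1, 1, 1)
    return torch.equal(padding2d_circular(x, pad), F.pad(x, pad, mode="circular"))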
3982

3983

3984
loss_reference_fns = {
3985
    "KLDivLoss": kldivloss_reference,
3986
    "NLLLoss": nllloss_reference,
3987
    "NLLLossNd": nlllossNd_reference,
3988
    "SmoothL1Loss": smoothl1loss_reference,
3989
    "MultiLabelMarginLoss": multilabelmarginloss_reference,
3990
    "HingeEmbeddingLoss": hingeembeddingloss_reference,
3991
    "SoftMarginLoss": softmarginloss_reference,
3992
    "MultiMarginLoss": multimarginloss_reference,
3993
    "CosineEmbeddingLoss": cosineembeddingloss_reference,
3994
    "TripletMarginLoss": tripletmarginloss_reference,
3995
    "MarginRankingLoss": marginrankingloss_reference,
3996
    "CTCLoss": ctcloss_reference,
3997
}
3998

3999

4000
criterion_tests = [
4001
    dict(
4002
        module_name="L1Loss",
4003
        input_size=(2, 3, 4),
4004
        target_size=(2, 3, 4),
4005
        reference_fn=lambda i, t, _: 1.0
4006
        / i.numel()
4007
        * sum((a - b).abs().sum() for a, b in zip(i, t)),
4008
    ),
4009
    dict(
4010
        module_name="NLLLoss",
4011
        input_fn=lambda: torch.rand(15, 10).log(),
4012
        target_fn=lambda: torch.Tensor(15).uniform_().mul(10).floor().long(),
4013
        reference_fn=lambda i, t, m: nllloss_reference(
4014
            i, t, reduction=get_reduction(m)
4015
        ),
4016
        check_sum_reduction=True,
4017
        check_bfloat16=TEST_WITH_ROCM,
4018
    ),
4019
    dict(
4020
        module_name="NLLLoss",
4021
        constructor_args=(None, None, 2),
4022
        cpp_constructor_args="torch::nn::NLLLossOptions().weight({}).ignore_index(2)",
4023
        input_fn=lambda: torch.rand(15, 10).log(),
4024
        target_fn=lambda: torch.Tensor(15).uniform_().mul(10).floor().long(),
4025
        reference_fn=lambda i, t, _: nllloss_reference(i, t, ignore_index=2),
4026
        desc="ignore_index",
4027
        check_bfloat16=TEST_WITH_ROCM,
4028
    ),
4029
    dict(
4030
        module_name="NLLLoss",
4031
        constructor_args_fn=lambda: (torch.rand(10),),
4032
        cpp_constructor_args="torch::nn::NLLLossOptions().weight(torch::rand(10))",
4033
        input_fn=lambda: torch.rand(15, 10).add(1e-2).log(),
4034
        target_fn=lambda: torch.Tensor(15).uniform_().mul(10).floor().long(),
4035
        reference_fn=lambda i, t, m: nllloss_reference(i, t, weight=get_weight(m)),
4036
        desc="weights",
4037
        check_bfloat16=TEST_WITH_ROCM,
4038
    ),
4039
    dict(
4040
        module_name="NLLLoss",
4041
        constructor_args_fn=lambda: (torch.rand(10), None, 2),
4042
        cpp_constructor_args="torch::nn::NLLLossOptions().weight(torch::rand(10)).ignore_index(2)",
4043
        input_fn=lambda: torch.rand(15, 10).add(1e-2).log(),
4044
        target_fn=lambda: torch.Tensor(15).uniform_().mul(10).floor().long(),
4045
        reference_fn=lambda i, t, m: nllloss_reference(
4046
            i, t, weight=get_weight(m), ignore_index=2
4047
        ),
4048
        desc="weights_ignore_index",
4049
        check_bfloat16=TEST_WITH_ROCM,
4050
    ),
4051
    dict(
4052
        module_name="NLLLoss",
4053
        constructor_args_fn=lambda: (torch.rand(10), None, -1),
4054
        cpp_constructor_args="torch::nn::NLLLossOptions().weight(torch::rand(10)).ignore_index(-1)",
4055
        input_fn=lambda: torch.rand(15, 10).add(1e-2).log(),
4056
        target_fn=lambda: torch.Tensor(15).uniform_().mul(10 + 1).floor().long() - 1,
4057
        reference_fn=lambda i, t, m: nllloss_reference(
4058
            i, t, weight=get_weight(m), ignore_index=-1
4059
        ),
4060
        desc="weights_ignore_index_neg",
4061
        check_bfloat16=TEST_WITH_ROCM,
4062
    ),
    dict(
        module_name="KLDivLoss",
        input_fn=lambda: torch.rand(10, 10).log(),
        target_fn=lambda: torch.rand(10, 10),
        reference_fn=lambda i, t, m: kldivloss_reference(i, t, get_reduction(m)),
        check_sum_reduction=True,
    ),
    dict(
        module_name="MSELoss",
        input_size=(2, 3, 4, 5),
        target_size=(2, 3, 4, 5),
        reference_fn=lambda i, t, m: (
            (i - t).abs().pow(2).sum()
            / (i.numel() if get_reduction(m) == "mean" else 1)
        ),
        check_sum_reduction=True,
    ),
    dict(
        module_name="BCELoss",
        input_fn=lambda: torch.rand(15, 10).clamp_(1e-2, 1 - 1e-2),
        target_fn=lambda: torch.randn(15, 10).gt(0).double(),
        reference_fn=lambda i, t, m: -(t * i.log() + (1 - t) * (1 - i).log()).sum()
        / (i.numel() if get_reduction(m) else 1),
        check_gradgrad=False,
        check_bfloat16=TEST_WITH_ROCM,
    ),
    dict(
        module_name="BCELoss",
        constructor_args_fn=lambda: (torch.rand(10),),
        cpp_constructor_args="torch::nn::BCELossOptions().weight(torch::rand(10))",
        input_fn=lambda: torch.rand(15, 10).clamp_(1e-2, 1 - 1e-2),
        target_fn=lambda: torch.randn(15, 10).gt(0).double(),
        reference_fn=lambda i, t, m: -(
            (t * i.log() + (1 - t) * (1 - i).log()) * get_weight(m)
        ).sum()
        / (i.numel() if get_reduction(m) else 1),
        desc="weights",
        check_gradgrad=False,
        check_bfloat16=TEST_WITH_ROCM,
    ),
    dict(
        module_name="CrossEntropyLoss",
        input_size=(15, 10),
        target_fn=lambda: torch.Tensor(15).uniform_().mul(10).floor().long(),
    ),
    dict(
        module_name="CrossEntropyLoss",
        constructor_args_fn=lambda: (torch.rand(10),),
        cpp_constructor_args="torch::nn::CrossEntropyLossOptions().weight(torch::rand(10))",
        input_size=(15, 10),
        target_fn=lambda: torch.Tensor(15).uniform_().mul(10).floor().long(),
        desc="weights",
    ),
    dict(
        module_name="HingeEmbeddingLoss",
        input_size=(10,),
        target_fn=lambda: torch.randn(10).gt(0).double().mul_(2).sub(1),
        reference_fn=lambda i, t, m: hingeembeddingloss_reference(
            i, t, reduction=get_reduction(m)
        ),
        check_sum_reduction=True,
    ),
    dict(
        module_name="HingeEmbeddingLoss",
        constructor_args=(0.5,),
        cpp_constructor_args="torch::nn::HingeEmbeddingLossOptions().margin(0.5)",
        input_size=(10,),
        target_fn=lambda: torch.randn(10).gt(0).double().mul_(2).sub(1),
        reference_fn=lambda i, t, m: hingeembeddingloss_reference(
            i, t, margin=0.5, reduction=get_reduction(m)
        ),
        desc="margin",
        check_sum_reduction=True,
    ),
    dict(
        module_name="MultiLabelMarginLoss",
        input_size=(10,),
        target_fn=lambda: torch.rand(10).mul(10).floor().long(),
        reference_fn=lambda i, t, m: multilabelmarginloss_reference(
            i, t, reduction=get_reduction(m)
        ),
        desc="1d",
        check_sum_reduction=True,
        check_gradgrad=False,
        check_bfloat16=TEST_WITH_ROCM,
    ),
    dict(
        module_name="MultiLabelMarginLoss",
        input_size=(5, 10),
        target_fn=lambda: torch.rand(5, 10).mul(10).floor().long(),
        reference_fn=lambda i, t, m: multilabelmarginloss_reference(
            i, t, reduction=get_reduction(m)
        ),
        check_sum_reduction=True,
        check_gradgrad=False,
        check_bfloat16=TEST_WITH_ROCM,
    ),
    dict(
        module_name="MultiLabelSoftMarginLoss",
        input_size=(5, 10),
        target_fn=lambda: torch.rand(5, 10).mul(2).floor(),
        reference_fn=lambda i, t, m: -(
            t * i.sigmoid().log() + (1 - t) * (-i).sigmoid().log()
        ).sum()
        / i.numel(),
        check_gradgrad=False,
    ),
    dict(
        module_name="MultiMarginLoss",
        input_size=(5, 10),
        target_fn=lambda: torch.rand(5).mul(8).floor().long(),
        reference_fn=lambda i, t, m: multimarginloss_reference(
            i, t, reduction=get_reduction(m)
        ),
        check_sum_reduction=True,
        check_gradgrad=False,
    ),
    dict(
        module_name="MultiMarginLoss",
        input_size=(10,),
        target_fn=lambda: torch.rand(1).mul(8).floor().long(),
        reference_fn=lambda i, t, m: multimarginloss_reference(
            i, t, reduction=get_reduction(m)
        ),
        desc="1d",
        check_sum_reduction=True,
        check_gradgrad=False,
    ),
    dict(
        module_name="MultiMarginLoss",
        constructor_args=(2,),
        cpp_constructor_args="torch::nn::MultiMarginLossOptions().p(2)",
        input_fn=lambda: torch.rand(5, 10).clamp_(1e-2, 1 - 1e-2),
        target_fn=lambda: torch.rand(5).mul(8).floor().long(),
        reference_fn=lambda i, t, m: multimarginloss_reference(
            i, t, p=2, reduction=get_reduction(m)
        ),
        desc="p",
        check_sum_reduction=True,
        check_gradgrad=False,
    ),
    dict(
        module_name="MultiMarginLoss",
        constructor_args=(1, 0.5),
        cpp_constructor_args="torch::nn::MultiMarginLossOptions().p(1).margin(0.5)",
        legacy_constructor_args=(1, None, 0.5),
        input_size=(5, 10),
        target_fn=lambda: torch.rand(5).mul(8).floor().long(),
        reference_fn=lambda i, t, m: multimarginloss_reference(
            i, t, margin=0.5, reduction=get_reduction(m)
        ),
        desc="margin",
        check_sum_reduction=True,
        check_gradgrad=False,
    ),
    dict(
        module_name="MultiMarginLoss",
        constructor_args=(1, 1.0, torch.rand(10)),
        cpp_constructor_args="torch::nn::MultiMarginLossOptions().p(1).margin(1.).weight(torch::rand(10))",
        legacy_constructor_args=(1, torch.rand(10)),
        input_size=(5, 10),
        target_fn=lambda: torch.rand(5).mul(8).floor().long(),
        reference_fn=lambda i, t, m: multimarginloss_reference(
            i, t, weight=get_weight(m), reduction=get_reduction(m)
        ),
        desc="weights",
        check_sum_reduction=True,
        check_gradgrad=False,
    ),
    dict(
        module_name="SmoothL1Loss",
        input_size=(5, 10),
        target_size=(5, 10),
        check_sum_reduction=True,
        reference_fn=lambda i, t, m: smoothl1loss_reference(
            i, t, reduction=get_reduction(m)
        ),
    ),
    dict(
        module_name="SoftMarginLoss",
        input_size=(5, 5),
        target_fn=lambda: torch.randn(5, 5).sign(),
        reference_fn=lambda i, t, m: softmarginloss_reference(
            i, t, reduction=get_reduction(m)
        ),
        check_sum_reduction=True,
    ),
    dict(
        module_name="CosineEmbeddingLoss",
        input_fn=lambda: (torch.rand(15, 10), torch.rand(15, 10)),
        target_fn=lambda: torch.randn(15).sign(),
        reference_fn=lambda i, t, m: cosineembeddingloss_reference(
            i[0], i[1], t, reduction=get_reduction(m)
        ),
        check_sum_reduction=True,
    ),
    dict(
        module_name="CosineEmbeddingLoss",
        constructor_args=(0.7,),
        cpp_constructor_args="torch::nn::CosineEmbeddingLossOptions().margin(0.7)",
        input_fn=lambda: (torch.rand(15, 10), torch.rand(15, 10)),
        target_fn=lambda: torch.randn(15).sign(),
        reference_fn=lambda i, t, m: cosineembeddingloss_reference(
            i[0], i[1], t, margin=0.7, reduction=get_reduction(m)
        ),
        desc="margin",
        check_sum_reduction=True,
    ),
    dict(
        module_name="MarginRankingLoss",
        input_fn=lambda: (torch.randn(50).mul(10), torch.randn(50).mul(10)),
        target_fn=lambda: torch.randn(50).sign(),
        reference_fn=lambda i, t, m: marginrankingloss_reference(
            i[0], i[1], t, reduction=get_reduction(m)
        ),
        check_sum_reduction=True,
    ),
    dict(
        module_name="MarginRankingLoss",
        constructor_args=(0.5,),
        cpp_constructor_args="torch::nn::MarginRankingLossOptions().margin(0.5)",
        input_fn=lambda: (torch.randn(50).mul(10), torch.randn(50).mul(10)),
        target_fn=lambda: torch.randn(50).sign(),
        reference_fn=lambda i, t, m: marginrankingloss_reference(
            i[0], i[1], t, margin=0.5, reduction=get_reduction(m)
        ),
        desc="margin",
        check_sum_reduction=True,
    ),
]

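# A minimal, illustrative sketch of the dict schema used by the test lists in this
# file (the keys shown are the ones that actually appear above and below; the values
# here are placeholders, not an additional test):
#
#     dict(
#         module_name="MSELoss",             # nn.<module_name> gets constructed
#         constructor_args=(...),            # or constructor_args_fn=lambda: (...)
#         input_size=(2, 3),                 # or input_fn=lambda: ...
#         target_size=(2, 3),                # or target_fn=lambda: ...
#         reference_fn=lambda i, t, m: ...,  # expected value the module must match
#         check_sum_reduction=True,          # also exercise reduction='sum'
#         desc="...",                        # suffix for the generated test name
#     )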
new_criterion_tests = [
    dict(
        module_name="BCEWithLogitsLoss",
        input_fn=lambda: torch.rand(15, 10).clamp_(1e-2, 1 - 1e-2),
        target_fn=lambda: torch.randn(15, 10).gt(0).double(),
    ),
    dict(
        module_name="BCEWithLogitsLoss",
        constructor_args=(torch.rand(10),),
        cpp_constructor_args="torch::nn::BCEWithLogitsLossOptions().weight(torch::rand(10))",
        input_fn=lambda: torch.rand(15, 10).clamp_(1e-2, 1 - 1e-2),
        target_fn=lambda: torch.randn(15, 10).gt(0).double(),
        desc="weights",
    ),
    dict(
        module_name="BCEWithLogitsLoss",
        constructor_args=(torch.rand(()),),
        cpp_constructor_args="torch::nn::BCEWithLogitsLossOptions().weight(torch::rand({}))",
        input_fn=lambda: torch.rand(()).clamp_(1e-2, 1 - 1e-2),
        target_fn=lambda: torch.randn(()).gt(0).double(),
        desc="scalar_weights",
    ),
    dict(
        module_name="NLLLoss",
        input_size=(2, 3, 5, 5),
        target_fn=lambda: torch.rand(2, 5, 5).mul(3).floor().long(),
        reference_fn=lambda i, t, m: loss_reference_fns["NLLLossNd"](
            i, t, reduction=get_reduction(m)
        ),
        check_sum_reduction=True,
        desc="2d",
        check_bfloat16=TEST_WITH_ROCM,
    ),
    dict(
        module_name="NLLLoss",
        constructor_args_fn=lambda: (torch.rand(3),),
        cpp_constructor_args="torch::nn::NLLLossOptions().weight(torch::rand(3))",
        input_size=(2, 3, 5, 5),
        target=torch.rand(2, 5, 5).mul(3).floor().long(),
        reference_fn=lambda i, t, m: loss_reference_fns["NLLLossNd"](
            i, t, weight=get_weight(m)
        ),
        desc="2d_weights",
        check_bfloat16=TEST_WITH_ROCM,
    ),
    dict(
        module_name="NLLLoss",
        constructor_args=(None, None, 1),
        cpp_constructor_args="torch::nn::NLLLossOptions().weight({}).ignore_index(1)",
        input_size=(2, 3, 5, 5),
        target_fn=lambda: torch.rand(2, 5, 5).mul(3).floor().long(),
        reference_fn=lambda i, t, m: loss_reference_fns["NLLLossNd"](
            i, t, ignore_index=1
        ),
        desc="2d_ignore_index",
        check_bfloat16=TEST_WITH_ROCM,
    ),
    dict(
        module_name="NLLLoss",
        input_size=(2, 3, 5, 5, 2, 2),
        target_fn=lambda: torch.rand(2, 5, 5, 2, 2).mul(3).floor().long(),
        reference_fn=lambda i, t, m: loss_reference_fns["NLLLossNd"](
            i, t, reduction=get_reduction(m)
        ),
        check_sum_reduction=True,
        desc="higher_dim",
        check_bfloat16=TEST_WITH_ROCM,
    ),
    dict(
        module_name="NLLLoss",
        input_size=(2, 3, 5),
        target_fn=lambda: torch.rand(2, 5).mul(3).floor().long(),
        reference_fn=lambda i, t, m: loss_reference_fns["NLLLossNd"](
            i, t, reduction=get_reduction(m)
        ),
        check_sum_reduction=True,
        desc="dim_is_3",
        check_bfloat16=TEST_WITH_ROCM,
    ),
    dict(
        module_name="PoissonNLLLoss",  # Default is log_input=True, full=False
        input_size=(2, 3, 4, 5),
        target_fn=lambda: torch.randn(2, 3, 4, 5).floor_().abs_(),
        reference_fn=lambda i, t, _: (i.exp() - t.mul(i)).mean(),
        desc="no_full_loss",
    ),
    dict(
        module_name="PoissonNLLLoss",
        constructor_args=(False, False),  # log_input=False, full=False
        cpp_constructor_args="torch::nn::PoissonNLLLossOptions().log_input(false).full(false)",
        input_fn=lambda: torch.randn(2, 3, 4, 5).abs_().add_(0.001),
        target_fn=lambda: torch.randn(2, 3, 4, 5).floor_().abs_(),
        reference_fn=lambda i, t, _: (i - t.mul((i + 1e-8).log())).mean(),
        desc="no_full_loss_no_log_input",
    ),
    dict(
        module_name="PoissonNLLLoss",
        constructor_args=(True, True),  # log_input=True, full=True
        cpp_constructor_args="torch::nn::PoissonNLLLossOptions().log_input(true).full(true)",
        input_size=(2, 3, 4, 5),
        target_fn=lambda: torch.randn(2, 3, 4, 5).floor_().abs_(),
        reference_fn=lambda i, t, _: (
            i.exp()
            - t.mul(i)
            + (t.mul(t.log()) - t + 0.5 * (2.0 * pi * t).log()).masked_fill(t <= 1, 0)
        ).mean(),
        desc="full_loss",
    ),
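    # NOTE: the masked term in the full-loss references above and below is
    # Stirling's approximation of log(t!):
    #     log(t!) ≈ t * log(t) - t + 0.5 * log(2 * pi * t)
    # PoissonNLLLoss adds it when full=True, and the masked_fill(t <= 1, 0)
    # zeroes it for t <= 1, where the approximation is not applied.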
    dict(
        module_name="PoissonNLLLoss",
        constructor_args=(False, True),  # log_input=False, full=True
        cpp_constructor_args="torch::nn::PoissonNLLLossOptions().log_input(false).full(true)",
        input_fn=lambda: torch.randn(2, 3, 4, 5).abs_().add_(0.001),
        target_fn=lambda: torch.randn(2, 3, 4, 5).floor_().abs_(),
        reference_fn=lambda i, t, _: (
            i
            - t.mul((i + 1e-8).log())
            + (t.mul(t.log()) - t + 0.5 * (2.0 * pi * t).log()).masked_fill(t <= 1, 0)
        ).mean(),
        desc="full_loss_no_log_input",
    ),
    dict(
        module_name="L1Loss",
        input_size=(),
        target_size=(),
        reference_fn=lambda i, t, _: 1.0 / i.numel() * (i - t).abs().sum(),
        desc="scalar",
    ),
    dict(
        module_name="KLDivLoss",
        input_fn=lambda: torch.rand(()).log(),
        target_fn=lambda: torch.rand(()),
        reference_fn=lambda i, t, m: kldivloss_reference(i, t, get_reduction(m)),
        check_sum_reduction=True,
        desc="scalar",
    ),
    dict(
        module_name="MSELoss",
        input_size=(),
        target_size=(),
        reference_fn=lambda i, t, m: (
            (i - t).abs().pow(2).sum()
            / (i.numel() if get_reduction(m) == "mean" else 1)
        ),
        check_sum_reduction=True,
        desc="scalar",
        check_bfloat16=TEST_WITH_ROCM,
    ),
    dict(
        module_name="MSELoss",
        input_fn=lambda: torch.ones(5, 68, 64, 64, dtype=torch.float) / 10,
        target_fn=lambda: torch.zeros(5, 68, 64, 64, dtype=torch.float),
        reference_fn=lambda i, t, m: (
            (i - t).abs().pow(2).sum()
            / (i.numel() if get_reduction(m) == "mean" else 1)
        ),
        check_forward_only=True,
        desc="prec",
        check_bfloat16=TEST_WITH_ROCM,
    ),
    dict(
        module_name="BCELoss",
        constructor_args_fn=lambda: (torch.rand(()),),
        cpp_constructor_args="torch::nn::BCELossOptions().weight(torch::rand({}))",
        input_fn=lambda: torch.rand(()).clamp_(1e-2, 1 - 1e-2),
        target_fn=lambda: torch.rand(()).gt(0).double(),
        reference_fn=lambda i, t, m: -(
            (t * i.log() + (1 - t) * (1 - i).log()) * get_weight(m)
        ).sum()
        / (i.numel() if get_reduction(m) == "mean" else 1),
        desc="scalar_weights",
        check_gradgrad=False,
        check_bfloat16=TEST_WITH_ROCM,
    ),
    dict(
        module_name="HingeEmbeddingLoss",
        constructor_args=(0.5,),
        cpp_constructor_args="torch::nn::HingeEmbeddingLossOptions().margin(0.5)",
        input_size=(),
        target_fn=lambda: torch.randn(()).gt(0).double().mul_(2).sub(1),
        desc="scalar_margin",
        check_sum_reduction=True,
    ),
    dict(
        module_name="SmoothL1Loss",
        input_size=(),
        target_size=(),
        check_sum_reduction=True,
        reference_fn=lambda i, t, m: smoothl1loss_reference(
            i, t, reduction=get_reduction(m)
        ),
        desc="scalar",
    ),
    dict(
        module_name="MultiLabelSoftMarginLoss",
        constructor_args=(torch.rand(10),),
        cpp_constructor_args="torch::nn::MultiLabelSoftMarginLossOptions().weight(torch::rand(10))",
        input_fn=lambda: torch.randn(5, 10),
        target_fn=lambda: torch.rand(5, 10).mul(2).floor(),
        reference_fn=lambda i, t, m: -(
            (t * i.sigmoid().log() + (1 - t) * (-i).sigmoid().log()) * get_weight(m)
        ).sum()
        / (
            i.numel()
            if get_reduction(m) == "mean"
            else i.size(1)
            if get_reduction(m) == "sum"
            else 1
        ),
        desc="weights",
        check_sum_reduction=True,
        check_gradgrad=False,
    ),
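    # NOTE: MultiLabelSoftMarginLoss averages over the class dimension before the
    # reduction is applied, so the reference above divides by i.numel() for
    # reduction='mean' and still divides by i.size(1) (the number of classes)
    # for reduction='sum'.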
    dict(
        module_name="CTCLoss",
        constructor_args=(14,),  # blank=14
        extra_args=([50, 50, 50], [30, 25, 20]),  # input_lengths, target_lengths
        input_fn=lambda: torch.randn(50, 3, 15).log_softmax(2),
        target_fn=lambda: torch.randint(0, 14, (3, 30), dtype=torch.long),
        reference_fn=lambda i, t, il, tl, m: ctcloss_reference(
            i, t, il, tl, blank=14, reduction=get_reduction(m)
        ),
        desc="lengths_intlists",
        check_sum_reduction=True,
        check_gradgrad=False,
        check_half=False,
        # `CTCLoss` in C++ frontend doesn't accept integer list for `input_lengths` or `target_lengths`
        test_cpp_api_parity=False,
    ),
    dict(
        module_name="CTCLoss",
        constructor_args=(14,),  # blank=14
        cpp_constructor_args="torch::nn::CTCLossOptions().blank(14)",
        extra_args=(
            torch.tensor([50, 50, 50]),
            torch.tensor([30, 25, 20]),
        ),  # input_lengths, target_lengths
        input_fn=lambda: torch.randn(50, 3, 15).log_softmax(2),
        target_fn=lambda: torch.randint(0, 14, (3, 30), dtype=torch.long),
        reference_fn=lambda i, t, il, tl, m: ctcloss_reference(
            i, t, il, tl, blank=14, reduction=get_reduction(m)
        ),
        desc="lengths_tensors",
        check_sum_reduction=True,
        check_gradgrad=False,
        check_half=False,
    ),
    # Test is flaky
    # See https://github.com/pytorch/pytorch/issues/29380.
    # dict(
    #     module_name='CTCLoss',
    #     desc='1d_target',
    #     constructor_args=(14,),  # blank=14
    #     extra_args=([50, 50, 50], [30, 25, 20]),  # input_lengths, target_lengths
    #     input_fn=lambda: torch.randn(50, 3, 15).log_softmax(2),
    #     target_fn=lambda: torch.randint(0, 14, (3, 30), dtype=torch.long),
    #     reference_fn=lambda i, t, il, tl, m:
    #         ctcloss_reference(i, t, il, tl, blank=14, reduction=get_reduction(m)),
    #     check_sum_reduction=True,
    #     check_gradgrad=False,
    #     check_half=False,
    # ),
    dict(
        module_name="CTCLoss",
        desc="2d_int_target_lengths_intlists",
        constructor_args=(0,),  # blank=0
        extra_args=([50, 50, 50], [30, 25, 20]),  # input_lengths, target_lengths
        input_fn=lambda: torch.randn(50, 3, 15).log_softmax(2),
        target_fn=lambda: torch.randint(1, 15, (3, 30), dtype=torch.int),
        reference_fn=lambda i, t, il, tl, m: ctcloss_reference(
            i, t, il, tl, blank=0, reduction=get_reduction(m)
        ),
        check_sum_reduction=True,
        check_gradgrad=False,
        check_half=False,
        convert_target=False,
        # `CTCLoss` in C++ frontend doesn't accept integer list for `input_lengths` or `target_lengths`
        test_cpp_api_parity=False,
    ),
    dict(
        module_name="CTCLoss",
        desc="2d_int_target_lengths_tensors",
        constructor_args=(0,),  # blank=0
        cpp_constructor_args="torch::nn::CTCLossOptions().blank(0)",
        extra_args=(
            torch.tensor([50, 50, 50]),
            torch.tensor([30, 25, 20]),
        ),  # input_lengths, target_lengths
        input_fn=lambda: torch.randn(50, 3, 15).log_softmax(2),
        target_fn=lambda: torch.randint(1, 15, (3, 30), dtype=torch.int),
        reference_fn=lambda i, t, il, tl, m: ctcloss_reference(
            i, t, il, tl, blank=0, reduction=get_reduction(m)
        ),
        check_sum_reduction=True,
        check_gradgrad=False,
        check_half=False,
        convert_target=False,
    ),
    dict(
        module_name="CTCLoss",
        desc="2d_lengths_tensors",
        constructor_args=(0,),  # blank=0
        cpp_constructor_args="torch::nn::CTCLossOptions().blank(0)",
        extra_args=(
            torch.tensor([50, 50, 50]),
            torch.tensor([30, 25, 20]),
        ),  # input_lengths, target_lengths
        input_fn=lambda: torch.randn(50, 3, 15).log_softmax(2),
        target_fn=lambda: torch.randint(1, 15, (3, 30), dtype=torch.int),
        reference_fn=lambda i, t, il, tl, m: ctcloss_reference(
            i, t, il, tl, blank=0, reduction=get_reduction(m)
        ),
        check_sum_reduction=True,
        check_gradgrad=False,
        check_half=False,
        convert_target=False,
    ),
]


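# NNTestCase below checks gradients by assembling a Jacobian column-by-column from
# autograd backward calls (_analytical_jacobian) and comparing it against a
# finite-difference estimate (_numerical_jacobian); the two must agree to within
# PRECISION.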
class NNTestCase(TestCase):
    def _jacobian(self, input, num_out):
        if isinstance(input, tuple):
            return tuple(self._jacobian(elem, num_out) for elem in input)
        elif isinstance(input, list):
            return [self._jacobian(elem, num_out) for elem in input]
        else:
            return torch.zeros(input.nelement(), num_out)

    def _flatten_tensors(self, x):
        if isinstance(x, torch.Tensor):
            if x.is_sparse:
                return x.to_dense().view(-1)
            else:
                return x.view(-1)
        else:
            return tuple(self._flatten_tensors(a) for a in x)

    def _zero_grad_input(self, input):
        if isinstance(input, torch.Tensor):
            if input.requires_grad and input.grad is not None:
                input.grad.zero_()
                input.grad.detach_()
        else:
            for i in input:
                self._zero_grad_input(i)

    def _analytical_jacobian(
        self, module, input, jacobian_input=True, jacobian_parameters=True
    ):
        output = self._forward(module, input)
        output_size = output.nelement()

        if jacobian_input:
            jacobian_inp = self._jacobian(input, output_size)
            flat_jacobian_input = list(iter_tensors(jacobian_inp))

        if jacobian_parameters:
            num_param = sum(p.numel() for p in self._get_parameters(module)[0])
            jacobian_param = torch.zeros(num_param, output_size)

        for i in range(output_size):
            param, d_param = self._get_parameters(module)
            # make non grad zeros
            d_param = [
                torch.zeros_like(p) if d is None else d
                for (p, d) in zip(param, d_param)
            ]
            d_out = torch.zeros_like(output)
            flat_d_out = d_out.view(-1)
            flat_d_out[i] = 1
            if jacobian_parameters:
                self._zero_grad_parameters(module)
            # Tensors will accumulate gradient from multiple steps
            if jacobian_input:
                self._zero_grad_input(input)
            d_input = self._backward(module, input, output, d_out)
            if jacobian_input:
                for jacobian_x, d_x in zip(flat_jacobian_input, iter_tensors(d_input)):
                    jacobian_x[:, i] = d_x.contiguous().view(-1)
            if jacobian_parameters:
                jacobian_param[:, i] = torch.cat(self._flatten_tensors(d_param), 0)

        res = tuple()
        if jacobian_input:
            res += (jacobian_inp,)
        if jacobian_parameters:
            res += (jacobian_param,)

        return res

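    # _numerical_jacobian relies on get_numerical_jacobian, which uses central
    # differences; roughly, for each perturbed element j and each output element i:
    #
    #     J[j, i] ≈ (f(x + eps * e_j)[i] - f(x - eps * e_j)[i]) / (2 * eps)
    #
    # with eps=1e-6, matching the layout built by _analytical_jacobian above.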
    def _numerical_jacobian(
        self, module, input, jacobian_input=True, jacobian_parameters=True
    ):
        def fw(input):
            return self._forward(module, input).detach()

        res = tuple()
        if jacobian_input:
            res += (get_numerical_jacobian(fw, input, eps=1e-6),)
        if jacobian_parameters:
            param, _ = self._get_parameters(module)
            res += (
                torch.cat(
                    [get_numerical_jacobian(fw, input, p, eps=1e-6) for p in param], 0
                ),
            )
        return res

    def check_jacobian(self, module, input, jacobian_input=True):
        jacobian_parameters = bool(self._get_parameters(module)[0])
        analytical = self._analytical_jacobian(
            module, input, jacobian_input, jacobian_parameters
        )
        numerical = self._numerical_jacobian(
            module, input, jacobian_input, jacobian_parameters
        )
        analytical_t = list(iter_tensors(analytical))
        numerical_t = list(iter_tensors(numerical))

        # TODO: compare structure
        if input.numel() != 0:
            self.assertLessEqual(
                max(
                    a.add(n, alpha=-1).abs().max()
                    for a, n in zip(analytical_t, numerical_t)
                ),
                PRECISION,
            )

    def check_criterion_jacobian(self, criterion, input, target):
        eps = 1e-6
        self._forward_criterion(criterion, input, target)
        analytical_d_x = self._backward_criterion(criterion, input, target)
        numerical_d_x = deepcopy(analytical_d_x)

        input_t = iter_tensors(input)
        numerical_t = iter_tensors(numerical_d_x)
        for x, d_x in zip(input_t, numerical_t):
            x = x.view(-1).data
            d_x = d_x.view(-1).data
            for i in range(x.nelement()):
                original = x[i].item()
                x[i] = original + eps
                fx1 = self._forward_criterion(criterion, input, target)
                x[i] = original - eps
                fx2 = self._forward_criterion(criterion, input, target)
                deriv = (fx1 - fx2) / (2.0 * eps)
                d_x[i] = float(deriv)
                x[i] = original

        # TODO: check structure
        analytical_t = list(iter_tensors(analytical_d_x))
        numerical_t = list(iter_tensors(numerical_d_x))

        self.assertLessEqual(
            max(
                a.add(n, alpha=-1).abs().max()
                for a, n in zip(analytical_t, numerical_t)
            ),
            PRECISION,
        )


class TestBase(object):
    _required_arg_names = {"constructor_args", "input", "extra_args"}

    def __init__(
        self, constructor, desc="", reference_fn=None, fullname=None, **kwargs
    ):
        self.desc = desc
        self.fullname = fullname
        self.constructor = constructor
        self.reference_fn = reference_fn
        for name in self._required_arg_names:
            if (
                name not in kwargs
                and name + "_fn" not in kwargs
                and name + "_size" not in kwargs
            ):
                if name in {"constructor_args", "extra_args"}:
                    kwargs[name] = tuple()
                else:
                    raise ValueError(
                        "{}: Specify {} by a value, a function to generate it, or its size!".format(
                            self.get_name(), name
                        )
                    )
        self._extra_kwargs = kwargs
        self._arg_cache = {}

    def get_name(self):
        if self.fullname is not None:
            return "test_" + self.fullname

        test_name = "test_" + self.constructor.__name__
        if self.desc:
            test_name += "_" + self.desc
        return test_name

    def _unpack(self, value):
        if isinstance(value, torch.Tensor):
            return value
        elif is_iterable(value):
            return type(value)(self._unpack(v) for v in value)
        else:
            return value

    @property
    def constructor_args(self):
        return self._get_arg("constructor_args", True)

    @property
    def extra_args(self):
        return self._get_arg("extra_args", True)

    def _get_arg(self, name, unpack):
        assert name in self._required_arg_names

        if name not in self._arg_cache:
            fn_name = name + "_fn"
            size_name = name + "_size"

            if name in self._extra_kwargs:
                self._arg_cache[name] = self._extra_kwargs[name]
            elif fn_name in self._extra_kwargs:
                self._arg_cache[name] = self._extra_kwargs[fn_name]()
            else:
                assert (
                    size_name in self._extra_kwargs
                ), "Missing `{}`, `{}` or `{}` for {}".format(
                    name, size_name, fn_name, self.get_name()
                )

                def map_tensor_sizes(sizes):
                    if isinstance(sizes, list):
                        return [map_tensor_sizes(s) for s in sizes]
                    elif isinstance(sizes, torch.Tensor):
                        return sizes.double()
                    else:
                        return torch.randn(sizes)

                self._arg_cache[name] = map_tensor_sizes(self._extra_kwargs[size_name])

        return self._unpack(self._arg_cache[name]) if unpack else self._arg_cache[name]

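    # Example (illustrative, not an actual test): each required argument can be
    # supplied as a literal value, a factory, or a size, and _get_arg resolves
    # them in that priority order:
    #
    #     TestBase(nn.MSELoss, input=torch.randn(2, 3))              # literal value
    #     TestBase(nn.MSELoss, input_fn=lambda: torch.randn(2, 3))   # factory
    #     TestBase(nn.MSELoss, input_size=(2, 3))                    # random tensor
    #
    # Sizes are expanded into torch.randn tensors by map_tensor_sizes above.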
    def _get_input(self, unpack=True):
        return self._get_arg("input", unpack)

    def __call__(self, test_case):
        raise NotImplementedError


class ModuleTest(TestBase):
    def __init__(self, *args, **kwargs):
        super(ModuleTest, self).__init__(*args, **kwargs)
        self.jacobian_input = kwargs.get("jacobian_input", True)
        self.should_test_cuda = kwargs.get("test_cuda", True)
        self.should_test_pickle = kwargs.get("pickle", True)
        self.check_gradgrad = kwargs.get("check_gradgrad", True)
        self.FIXME_no_cuda_gradgrad_comparison = kwargs.get(
            "FIXME_no_cuda_gradgrad_comparison", False
        )
        self.precision = kwargs.get("precision", 2e-4)
        self.check_forward_only = kwargs.get("check_forward_only", False)

    def __call__(self, test_case):
        module = self.constructor(*self.constructor_args).to("xpu")
        input = self._get_input()

        if self.reference_fn is not None:
            out = test_case._forward(module, input)
            ref_input = deepcopy(input)
            ref_module = deepcopy(module)
            expected_out = self.reference_fn(
                ref_input, test_case._get_parameters(module)[0], ref_module
            )
            test_case.assertEqual(out, expected_out)
        unsupported_backward_modules = [
            "Conv1d",
            "Conv2d",
            "Conv3d",
            "ConvTranspose1d",
            "ConvTranspose2d",
            "ConvTranspose3d",
        ]
        if (
            module._get_name() in unsupported_backward_modules
            and input.dtype == torch.float64
        ):
            return
        if self.check_forward_only:
            return
        self.test_noncontig(test_case, module, input)

        if self.should_test_pickle:
            # TODO: do this with in-memory files as soon as torch.save will support it
            with TemporaryFile() as f:
                test_case._forward(module, input)
                torch.save(module, f)
                f.seek(0)
                module_copy = torch.load(f)
                test_case.assertEqual(
                    test_case._forward(module, input),
                    test_case._forward(module_copy, input),
                )

        self._do_test(test_case, module, input)

    def noncontiguize(self, obj):
        if isinstance(obj, list):
            return [self.noncontiguize(o) for o in obj]
        tensor = obj
        ndim = tensor.dim()
        # Always making only the last dimension noncontiguous makes it easy to hide
        # bugs because .view(-1) will still work. So try to find a dim with size
        # > 1 and make that non-contiguous, i.e., stack + select on the
        # dimension directly after that.
        dim = ndim
        for d in range(ndim):
            if tensor.size(d) > 1:
                dim = d + 1
                break
        noncontig = (
            torch.stack([torch.empty_like(tensor), tensor], dim).select(dim, 1).detach()
        )
        assert (
            noncontig.numel() == 1
            or noncontig.numel() == 0
            or not noncontig.is_contiguous()
        )
        noncontig.requires_grad = tensor.requires_grad
        return noncontig

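    # Example (illustrative) of what noncontiguize produces: for a contiguous
    # t = torch.randn(5, 10), stacking with an empty tensor along dim 1 and
    # selecting index 1 yields the same values but with non-contiguous strides:
    #
    #     nc = torch.stack([torch.empty_like(t), t], 1).select(1, 1)
    #     assert nc.eq(t).all() and not nc.is_contiguous()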
    def test_noncontig(self, test_case, module, input):
        # check no scalars, can't make non-contig
        if isinstance(input, torch.Tensor) and input.dim() == 0:
            return
        if any(i.dim() == 0 for i in input if isinstance(i, torch.Tensor)):
            return

        test_case._zero_grad_parameters(module)
        test_case._zero_grad_input(input)
        with freeze_rng_state():
            output = test_case._forward(module, input)
            grad_output = output.new(output.shape).normal_()
            output = output.clone()
            d_input = deepcopy(test_case._backward(module, input, output, grad_output))
            d_param = deepcopy(test_case._get_parameters(module)[1])

        nc_input = self.noncontiguize(input)
        nc_grad_output = self.noncontiguize(grad_output)
        for contig_i, contig_g in product((True, False), repeat=2):
            i = input if contig_i else nc_input
            # Some ops, e.g., nn.Flatten, return gradient that shares
            # storage with the grad_output. Hence we copy here.
            go = deepcopy(grad_output if contig_g else nc_grad_output)
            test_case._zero_grad_parameters(module)
            test_case._zero_grad_input(i)
            with freeze_rng_state():
                out = test_case._forward(module, i)
                grad = test_case._backward(module, i, out, go)

                test_case.assertEqual(out, output)
                test_case.assertEqual(grad, d_input, 1e-4)
                test_case.assertEqual(test_case._get_parameters(module)[1], d_param)

    def test_cuda(self, test_case):
        if not TEST_CUDA or not self.should_test_cuda:
            raise unittest.SkipTest("Excluded from CUDA tests")
        try:
            cpu_input = self._get_input()
            type_map = {"torch.DoubleTensor": torch.cuda.FloatTensor}
            gpu_input = to_gpu(cpu_input, type_map=type_map)

            cpu_module = self.constructor(*self.constructor_args)
            gpu_module = self.constructor(*self.constructor_args).float().cuda()
            cpu_param = test_case._get_parameters(cpu_module)
            gpu_param = test_case._get_parameters(gpu_module)
            for cpu_p, gpu_p in zip(cpu_param[0], gpu_param[0]):
                gpu_p.data.copy_(cpu_p)

            test_case._zero_grad_input(cpu_input)
            test_case._zero_grad_input(gpu_input)
            test_case._zero_grad_parameters(cpu_module)
            test_case._zero_grad_parameters(gpu_module)
            cpu_output = test_case._forward(cpu_module, cpu_input)
            gpu_output = test_case._forward(gpu_module, gpu_input)
            test_case.assertEqual(cpu_output, gpu_output, self.precision)

            # Run backwards on CPU and GPU and compare results
            for _ in range(5):
                cpu_gradOutput = cpu_output.clone().normal_()
                gpu_gradOutput = cpu_gradOutput.type("torch.cuda.FloatTensor")
                cpu_gradInput = test_case._backward(
                    cpu_module, cpu_input, cpu_output, cpu_gradOutput
                )
                gpu_gradInput = test_case._backward(
                    gpu_module, gpu_input, gpu_output, gpu_gradOutput
                )
                test_case.assertEqual(cpu_gradInput, gpu_gradInput, self.precision)
                for cpu_d_p, gpu_d_p in zip(cpu_param[1], gpu_param[1]):
                    test_case.assertEqual(cpu_d_p, gpu_d_p, self.precision)

            # Run double-backwards on CPU and GPU and compare results
            if self.check_gradgrad and not self.FIXME_no_cuda_gradgrad_comparison:
                cpu_output = cpu_module(cpu_input)
                gpu_output = gpu_module(gpu_input)

                cpu_gradOutput = torch.randn_like(cpu_output, requires_grad=True)
                gpu_gradOutput = cpu_gradOutput.type_as(gpu_output).detach()
                gpu_gradOutput.requires_grad = True

                cpu_gradInputs = torch.autograd.grad(
                    cpu_output,
                    (cpu_input,) + tuple(cpu_module.parameters()),
                    cpu_gradOutput,
                    create_graph=True,
                )
                gpu_gradInputs = torch.autograd.grad(
                    gpu_output,
                    (gpu_input,) + tuple(gpu_module.parameters()),
                    gpu_gradOutput,
                    create_graph=True,
                )

                for cpu_d_i, gpu_d_i in zip(cpu_gradInputs, gpu_gradInputs):
                    test_case.assertEqual(cpu_d_i, gpu_d_i, self.precision)

                # We mix output into the second backwards computation so that
                # torch.autograd.grad doesn't complain that some inputs
                # are unreachable (which can happen if you differentiate
                # only on the gradient).
                cpu_gg = torch.autograd.grad(
                    cpu_output.sum() + sum(map(lambda x: x.sum(), cpu_gradInputs)),
                    (cpu_input, cpu_gradOutput) + tuple(cpu_module.parameters()),
                    retain_graph=True,
                )
                gpu_gg = torch.autograd.grad(
                    gpu_output.sum() + sum(map(lambda x: x.sum(), gpu_gradInputs)),
                    (gpu_input, gpu_gradOutput) + tuple(gpu_module.parameters()),
                    retain_graph=True,
                )

                test_case.assertEqual(cpu_gradInput, gpu_gradInput, self.precision)
                for cpu_d_p, gpu_d_p in zip(cpu_gg, gpu_gg):
                    test_case.assertEqual(cpu_d_p, gpu_d_p, self.precision)

            self.test_noncontig(test_case, gpu_module, gpu_input)
        except NotImplementedError:
            pass
        # TODO: remove this after CUDA scatter_ is implemented
        except AttributeError as e:
            if (
                len(e.args) == 1
                and "'FloatTensor' object has no attribute 'scatter_'" in e.args[0]
            ):
                pass
            else:
                raise


class CriterionTest(TestBase):
    _required_arg_names = TestBase._required_arg_names.union({"target"})

    def __init__(self, *args, **kwargs):
        super(CriterionTest, self).__init__(*args, **kwargs)
        self.should_test_cuda = kwargs.get("test_cuda", True)
        self.check_forward_only = kwargs.get("check_forward_only", True)

    def _get_target(self):
        return self._get_arg("target", True)

    def __call__(self, test_case):
        module = self.constructor(*self.constructor_args)
        input = self._get_input()

        # Check that these methods don't raise errors
        module.__repr__()
        str(module)

        target = self._get_target()

        if self.reference_fn is not None:
            out = test_case._forward_criterion(
                module, input, target, extra_args=self.extra_args
            )
            ref_args = (deepcopy(input), deepcopy(target)) + self.extra_args + (module,)
            expected_out = self.reference_fn(*ref_args)
            test_case.assertEqual(out, expected_out)

        if self.check_forward_only:
            return

        test_case.check_criterion_jacobian(module, input, target)
        self._do_extra_tests(test_case, module, input, target)

    def test_cuda(self, test_case):
        if not TEST_CUDA or not self.should_test_cuda:
            raise unittest.SkipTest("Excluded from CUDA tests")
        try:
            cpu_input = self._get_input()
            type_map = {
                "torch.DoubleTensor": torch.cuda.FloatTensor,
            }
            gpu_input = to_gpu(cpu_input, type_map=type_map)

            cpu_target = self._get_target()
            gpu_target = to_gpu(cpu_target, type_map=type_map)

            cpu_module = self.constructor(*self.constructor_args)
            gpu_module = self.constructor(*self.constructor_args).float().cuda()

            cpu_output = test_case._forward_criterion(cpu_module, cpu_input, cpu_target)
            gpu_output = test_case._forward_criterion(gpu_module, gpu_input, gpu_target)
            test_case.assertEqual(cpu_output, gpu_output, 4e-4)

            gradOutput = torch.randn(())
            cpu_gradInput = test_case._backward_criterion(
                cpu_module, cpu_input, cpu_target, gradOutput
            )
            gpu_gradInput = test_case._backward_criterion(
                gpu_module, gpu_input, gpu_target, gradOutput
            )
            test_case.assertEqual(cpu_gradInput, gpu_gradInput, 4e-4)
        except NotImplementedError:
            pass

    def _do_extra_tests(self, test_case, module, input, target):
        pass


class InputVariableMixin(object):
    def _get_input(self):
        input = TestBase._get_input(self, False)

        def map_variables(i):
            if isinstance(i, torch.Tensor):
                if i.is_floating_point():
                    i.requires_grad = True
                return i
            else:
                return type(i)(map_variables(elem) for elem in i)

        return map_variables(input)


class NewModuleTest(InputVariableMixin, ModuleTest):
    def __init__(self, *args, **kwargs):
        super(NewModuleTest, self).__init__(*args, **kwargs)
        self.cudnn = kwargs.get("cudnn", False)
        self.check_inplace = kwargs.get("check_inplace", False)
        self.check_gradgrad = kwargs.get("check_gradgrad", True)
        self.skip_double = kwargs.get("skip_double", False)

    def _do_test(self, test_case, module, input):
        test_case.check_jacobian(module, input, self.jacobian_input)

        if self.check_gradgrad:
            # could probably unify check_jacobian above with this.
            params = tuple(x for x in module.parameters())
            _assertGradAndGradgradChecks(
                test_case,
                lambda x, *args, **kw: test_case._forward(module, x),
                (input,) + params,
            )

        # check if module can be printed
        module.__repr__()

        if self.check_inplace:
            # check if the inplace variant of the module gives the same result
            # as the out-of-place

            module_ip = self.constructor(*self.constructor_args, inplace=True)

            input_version = input._version
            with freeze_rng_state():
                output = module(input)
            test_case.assertEqual(input._version, input_version)

            input_ip = deepcopy(input)
            if input.device.type == "xpu":
                input_ip.requires_grad = True
            input_ip_clone = input_ip.clone()
            with freeze_rng_state():
                output_ip = module_ip(input_ip_clone)
            if input.device == torch.device("cpu"):
                test_case.assertNotEqual(input_ip_clone._version, input_version)
            test_case.assertEqual(output, output_ip)
            grad = output.data.clone().normal_()
            input.grad.data.zero_()
            output.backward(grad)
            output_ip.backward(grad)
            test_case.assertEqual(input.grad, input_ip.grad)

        if isinstance(input, torch.LongTensor) and TEST_CUDA:
            # check that cuda() moves module parameters to correct GPU device,
            # and that float() casts parameters correctly

            input = input.cuda()
            module.float().cuda()
            module(input)
            for p in module.parameters():
                test_case.assertIsInstance(p, torch.cuda.FloatTensor)
                test_case.assertEqual(p.get_device(), 0)

            if torch.cuda.device_count() > 1:
                input = input.cuda(1)
                module.cuda(1)
                with torch.cuda.device(1):
                    module(input)
                for p in module.parameters():
                    test_case.assertIsInstance(p, torch.cuda.FloatTensor)
                    test_case.assertEqual(p.get_device(), 1)
        else:
            # check that float()/double() casters work correctly

            # to float
            if input.device == torch.device("cpu"):
                if not isinstance(input, torch.LongTensor):
                    input = input.float()
                module.float()
                module(input)
                for p in module.parameters():
                    test_case.assertIsInstance(p, torch.FloatTensor)

                # and back to double
                if not isinstance(input, torch.LongTensor):
                    input = input.double()
                module.double()
                module(input)
                for p in module.parameters():
                    test_case.assertIsInstance(p, torch.DoubleTensor)
            # else: # for xpu
            #     print()
            #     if not isinstance(input, torch.xpu.LongTensor):
            #         input = input.float()
            #     module.float()
            #     module(input)
            #     for p in module.parameters():
            #         test_case.assertIsInstance(p, torch.xpu.FloatTensor)

            #     # and back to double
            #     if not isinstance(input, torch.xpu.LongTensor):
            #         input = input.double()
            #     module.double()
            #     module(input)
            #     for p in module.parameters():
            #         test_case.assertIsInstance(p, torch.xpu.DoubleTensor)

            if TEST_CUDA and self.should_test_cuda:
                # check that cuda() moves module parameters to correct GPU device,
                # and that float() casts parameters correctly

                # to GPU0
                input = input.float().cuda()
                module.float().cuda()
                module(input)
                for p in module.parameters():
                    test_case.assertIsInstance(p, torch.cuda.FloatTensor)
                    test_case.assertEqual(p.get_device(), 0)

                # to CPU
                input = input.cpu()
                module.cpu()
                module(input)
                for p in module.parameters():
                    test_case.assertIsInstance(p, torch.FloatTensor)

                # back to GPU0
                input = input.cuda()
                module.cuda()
                module(input)
                for p in module.parameters():
                    test_case.assertIsInstance(p, torch.cuda.FloatTensor)
                    test_case.assertEqual(p.get_device(), 0)

                # test that forward of the module runs correctly without cuDNN
                if self.cudnn:
                    with torch.backends.cudnn.flags(enabled=False):
                        module(input)
                        for p in module.parameters():
                            test_case.assertIsInstance(p, torch.cuda.FloatTensor)
                            test_case.assertEqual(p.get_device(), 0)

                if torch.cuda.device_count() >= 2:
                    # test cross-GPU transfer works
                    # to GPU1
                    input = input.cuda(1)
                    module.cuda(1)
                    with torch.cuda.device(1):
                        module(input)
                    for p in module.parameters():
                        test_case.assertIsInstance(p, torch.cuda.FloatTensor)
                        test_case.assertEqual(p.get_device(), 1)

                if not self.skip_double:
                    # test double()
                    input = input.double().cuda()
                    module.double().cuda()
                    module(input)
                    for p in module.parameters():
                        test_case.assertIsInstance(p, torch.cuda.DoubleTensor)
                        test_case.assertEqual(p.get_device(), 0)

                # test half()
                input = input.half().cuda()
                module.half().cuda()
                module(input)
                for p in module.parameters():
                    test_case.assertIsInstance(p, torch.cuda.HalfTensor)
                    test_case.assertEqual(p.get_device(), 0)

    def _get_target(self):
        return self._get_arg("target", False)

    @property
    def constructor_args(self):
        return self._get_arg("constructor_args", False)


class NewCriterionTest(InputVariableMixin, CriterionTest):
    # TODO: check that criterions don't ignore grad_output

    def __init__(self, *args, **kwargs):
        super(NewCriterionTest, self).__init__(*args, **kwargs)
        self.check_gradgrad = kwargs.get("check_gradgrad", True)
        self.check_half = kwargs.get("check_half", True)
        self.check_bfloat16 = kwargs.get("check_bfloat16", False)
        self.convert_target = kwargs.get("convert_target", True)

    def _do_extra_tests(self, test_case, module, input, target):
        if not self.check_gradgrad:
            return

        test_case.assertFalse(target.requires_grad)

        params = tuple(x for x in module.parameters())
        if not isinstance(input, tuple):
            inputs = (input,) + params

            def apply_fn(input, *params):
                return module(input, target)

        else:
            inputs = input + params

            def apply_fn(input1, input2, *params):
                return module(input1, input2, target)

        # TODO: we don't pass `target` as part of inputs because we don't
        # currently compute the gradient w.r.t. target for loss functions.
        gradcheck(apply_fn, inputs)
        gradgradcheck(apply_fn, inputs)

    def test_cuda(self, test_case, dtype=None, extra_args=None):
        def convert_dtype(obj, dtype, requires_grad=False):
            if isinstance(obj, torch.Tensor):
                return obj.detach().to(dtype=dtype).requires_grad_(requires_grad)
            elif isinstance(obj, torch.Tensor):
                return obj.to(dtype)
            elif isinstance(obj, tuple):
                return tuple(convert_dtype(o, dtype, requires_grad) for o in obj)
            else:
                return obj

        if not TEST_CUDA or not self.should_test_cuda:
            raise unittest.SkipTest("Excluded from CUDA tests")
        try:
            cpu_input = self._get_input()
            cpu_target = self._get_target()
            cpu_module = self.constructor(*self.constructor_args)
            gpu_module = self.constructor(*self.constructor_args)

            # Convert input, target and module parameters to dtype
            if dtype is not None:
                cpu_input = convert_dtype(cpu_input, dtype, True)
                # NLLLoss requires target to be LongTensor
                if not isinstance(cpu_target, torch.LongTensor) and self.convert_target:
                    cpu_target = convert_dtype(cpu_target, dtype)
                cpu_module.type(dtype)
                gpu_module.type(dtype)

            # GPU setup
            gpu_input = to_gpu(cpu_input)
            gpu_target = to_gpu(cpu_target)
            gpu_module.cuda()

            # torch.HalfTensor doesn't support most operations, converting back to default
            if dtype in {torch.half, torch.bfloat16}:
                cpu_input = self._get_input()
                cpu_target = self._get_target()
                # Loss modules with weights require consistent input/module weight types
                cpu_module = self.constructor(*self.constructor_args)

            cpu_output = test_case._forward_criterion(
                cpu_module, cpu_input, cpu_target, extra_args=extra_args
            )
            gpu_output = test_case._forward_criterion(
                gpu_module, gpu_input, gpu_target, extra_args=extra_args
            )
            # dtype can be None, so set precision in this way instead of a precision map
            test_case.assertEqual(
                cpu_output,
                gpu_output,
                1e-1 if dtype in {torch.half, torch.bfloat16} else 4e-4,
            )

            cpu_gradInput = test_case._backward_criterion(
                cpu_module, cpu_input, cpu_target, extra_args=extra_args
            )
            gpu_gradInput = test_case._backward_criterion(
                gpu_module, gpu_input, gpu_target, extra_args=extra_args
            )
            test_case.assertEqual(
                cpu_gradInput,
                gpu_gradInput,
                1e-1 if dtype in {torch.half, torch.bfloat16} else 4e-4,
            )
        except NotImplementedError:
            pass

    def _get_target(self):
        return self._get_arg("target", False)

    @property
    def constructor_args(self):
        return self._get_arg("constructor_args", False)

    @property
    def extra_args(self):
        return self._get_arg("extra_args", False)
