# intel-extension-for-pytorch
# 5426 lines · 196.0 KB
"""
From PyTorch:

Copyright (c) 2016- Facebook, Inc (Adam Paszke)
Copyright (c) 2014- Facebook, Inc (Soumith Chintala)
Copyright (c) 2011-2014 Idiap Research Institute (Ronan Collobert)
Copyright (c) 2012-2014 Deepmind Technologies (Koray Kavukcuoglu)
Copyright (c) 2011-2012 NEC Laboratories America (Koray Kavukcuoglu)
Copyright (c) 2011-2013 NYU (Clement Farabet)
Copyright (c) 2006-2010 NEC Laboratories America (Ronan Collobert, Leon Bottou, Iain Melvin, Jason Weston)
Copyright (c) 2006 Idiap Research Institute (Samy Bengio)
Copyright (c) 2001-2004 Idiap Research Institute (Ronan Collobert, Samy Bengio, Johnny Mariethoz)

From Caffe2:

Copyright (c) 2016-present, Facebook Inc. All rights reserved.

All contributions by Facebook:
Copyright (c) 2016 Facebook Inc.

All contributions by Google:
Copyright (c) 2015 Google Inc.
All rights reserved.

All contributions by Yangqing Jia:
Copyright (c) 2015 Yangqing Jia
All rights reserved.

All contributions from Caffe:
Copyright(c) 2013, 2014, 2015, the respective contributors
All rights reserved.

All other contributions:
Copyright(c) 2015, 2016 the respective contributors
All rights reserved.

Caffe2 uses a copyright model similar to Caffe: each contributor holds
copyright over their contributions to Caffe2. The project versioning records
all such contribution and copyright details. If a contributor wants to further
mark their specific copyright on a particular contribution, they should
indicate their copyright solely in the commit message of the change when it is
committed.

All rights reserved.
"""
46
47import math
48import sys
49import tempfile
50import unittest
51
52from copy import deepcopy
53from functools import reduce
54from itertools import product
55from operator import mul
56from math import pi
57
58
59import torch
60import torch.cuda
61import torch.nn as nn
62import torch.nn.functional as F
63from torch.nn.functional import _Reduction
64from common_utils import (
65TestCase,
66to_gpu,
67freeze_rng_state,
68is_iterable,
69TEST_WITH_ROCM,
70_assertGradAndGradgradChecks,
71)
72from torch.autograd.gradcheck import get_numerical_jacobian, iter_tensors
73from torch.autograd import Variable
74import torch.backends.cudnn
75
76TEST_CUDA = torch.cuda.is_available()
77
78# tarfile module tries to obtain a file object name in python 3.3
79if sys.version_info[:2] == (3, 3):
80TemporaryFile = tempfile.NamedTemporaryFile
81else:
82TemporaryFile = tempfile.TemporaryFile
83PRECISION = 1e-5
84
85
def get_reduction(m):
    """Return the reduction mode string for module *m*.

    Prefers the modern ``reduction`` attribute; when absent (or ``None``),
    translates the legacy ``sizeAverage`` flag via
    ``_Reduction.legacy_get_string`` without emitting a warning.
    """
    reduction = getattr(m, "reduction", None)
    if reduction is None:
        size_average = getattr(m, "sizeAverage", None)
        reduction = _Reduction.legacy_get_string(
            size_average, True, emit_warning=False
        )
    assert reduction is not None
    return reduction
94
95
def get_weight(m):
    """Return the weight tensor of module *m*, or ``None`` if it has none.

    Checks the standard ``weight`` attribute first, then the legacy
    ``weights`` spelling.
    """
    for attr_name in ("weight", "weights"):
        value = getattr(m, attr_name, None)
        if value is not None:
            return value
    return None
101
102
# NOTE [How to check NN module / functional API parity between Python and C++ frontends]
#
# The way to check API parity is to add parity tests for the NN module / functional of interest.
# Here are the detailed steps:
#
# For NN module:
# 1. Make sure you already have a test dict with the module configuration you want to test.
# 2. Add `cpp_constructor_args` entry to the test dict, with its value exactly matching
#    the Python module constructor arguments. For example, if in the test dict we pass
#    `(10, 8)` to `torch.nn.Linear` constructor, then we should pass `torch::nn::LinearOptions(10, 8)`
#    as the corresponding C++ constructor argument to `torch::nn::Linear`.
# 3. If in the process of performing the above step you referenced any variables
#    in the `cpp_constructor_args` entry, you must add `cpp_var_map` entry
#    to the test dict to make sure that those variables are populated with the right Python values.
#    For example, if the Python constructor call is
#    `torch.nn.FractionalMaxPool2d(2, output_ratio=0.5, _random_samples=random_samples)`,
#    the corresponding C++ constructor argument is
#    `torch::nn::FractionalMaxPool2dOptions(2).output_ratio(0.5)._random_samples(random_samples)`,
#    and the `cpp_var_map` entry must be
#    `{'random_samples': random_samples}` in order to populate the C++ variable `random_samples`
#    used in the C++ constructor argument with the Python tensor value `random_samples`.
#
# For NN functional:
# 1. Make sure you already have a test dict with the functional configuration you want to test.
# 2. If the test dict's `constructor` entry looks like `wrap_functional(F.some_functional_name, ...)`,
#    then you must add `cpp_options_args` entry to the test dict, with its value exactly matching the Python
#    functional optional arguments. For example, if the test dict's `constructor` entry is
#    `wrap_functional(F.interpolate, size=12, scale_factor=None, mode='nearest')`,
#    then the `cpp_options_args` entry should be
#    "F::InterpolateFuncOptions().size(std::vector<int64_t>({12})).scale_factor(c10::nullopt).mode(torch::kNearest)".
# 3. Otherwise, if the test dict's `constructor` entry looks like
#    `wrap_functional(lambda i: F.some_functional_name(...))`,
#    then you must add `cpp_function_call` entry to the test dict, with its value exactly matching the Python
#    functional function call. For example, if the test dict's `constructor` entry is
#    `wrap_functional(lambda i: F.poisson_nll_loss(i, t.type_as(i), reduction='none'))`,
#    then the `cpp_function_call` entry should be
#    "F::poisson_nll_loss(i, t.to(i.options()), F::PoissonNLLLossFuncOptions().reduction(torch::kNone))".
# 4. If in the process of performing the above two steps you referenced any variables
#    in the `cpp_options_args` or `cpp_function_call` entry, you must
#    add `cpp_var_map` entry to the test dict to make sure that those variables
#    are populated with the right Python values. For example, if the test dict's `constructor` entry is
#    `wrap_functional(lambda i: F.poisson_nll_loss(i, t.type_as(i), reduction='none'))`,
#    then the `cpp_function_call` entry should be
#    "F::poisson_nll_loss(i, t.to(i.options()), F::PoissonNLLLossFuncOptions().reduction(torch::kNone))".
#    Notice that there are two variables `i` and `t` that need to have their values provided,
#    and the way to do so is to add a `cpp_var_map` entry: `cpp_var_map={'i': '_get_input()', 't': t}`.
#    (Note that for `i`, since we want it to take the Python input value, we pass '_get_input()' string as value
#    and the C++ parity test mechanism will populate `i` with the Python input value correctly.)
#
# There are also a few optional flags in the test dict to control the C++ parity test behavior:
#
# - `test_cpp_api_parity`: if `False`, skips the C++ parity test for this test dict. Default: True.
# - `has_parity`: if `False`, expects this test dict to fail the C++ parity test. Default: True.
# Test configurations for simple (mostly activation) modules, consumed by the
# common NN test harness and the Python/C++ API parity machinery.
module_tests = [
    dict(
        module_name="Linear",
        constructor_args=(10, 8),
        cpp_constructor_args="torch::nn::LinearOptions(10, 8)",
        input_size=(4, 10),
        reference_fn=lambda i, p, _: torch.mm(i, p[0].t())
        + p[1].view(1, -1).expand(4, 8),
        check_gradgrad=False,
    ),
    dict(
        module_name="Linear",
        constructor_args=(10, 8, False),
        cpp_constructor_args="torch::nn::LinearOptions(10, 8).bias(false)",
        input_size=(4, 10),
        desc="no_bias",
        reference_fn=lambda i, p, _: torch.mm(i, p[0].t()),
        check_gradgrad=False,
    ),
    dict(
        module_name="Threshold",
        constructor_args=(2.0, 1.0),
        cpp_constructor_args="torch::nn::ThresholdOptions(2., 1.)",
        input_size=(2, 3, 4, 5),
        check_inplace=True,
        desc="threshold_value",
    ),
    dict(
        module_name="Threshold",
        constructor_args=(2.0, 10.0),
        cpp_constructor_args="torch::nn::ThresholdOptions(2., 10.)",
        input_size=(2, 3, 4, 5),
        desc="large_value",
    ),
    dict(
        module_name="ReLU",
        input_size=(2, 3, 4, 5),
        check_inplace=True,
    ),
    dict(
        module_name="ReLU6",
        input_size=(2, 3, 4, 5),
        check_inplace=True,
    ),
    dict(
        module_name="RReLU",
        input_size=(1, 2, 2),
        test_cuda=False,
    ),
    dict(
        module_name="RReLU",
        constructor_args=(0.1, 0.9),
        cpp_constructor_args="torch::nn::RReLUOptions().lower(0.1).upper(0.9)",
        input_size=(4, 4, 5),
        desc="with_up_down",
        test_cuda=False,
    ),
    dict(
        module_name="Hardtanh",
        input_size=(3, 2, 5),
        reference_fn=lambda i, *_: i.clamp(-1, 1),
    ),
    dict(
        module_name="Sigmoid",
        input_size=(2, 3, 4, 5),
    ),
    dict(
        module_name="Tanh",
        input_size=(2, 3, 4, 5),
    ),
    dict(
        module_name="Flatten",
        input_size=(2, 3, 4, 5),
        reference_fn=lambda i, *_: torch.flatten(i, 1),
    ),
    dict(
        module_name="Softmax",
        constructor_args=(1,),
        cpp_constructor_args="torch::nn::SoftmaxOptions(1)",
        input_size=(10, 20),
        reference_fn=lambda i, *_: torch.exp(i).div(
            torch.exp(i).sum(1, True).expand(10, 20)
        ),
    ),
    dict(
        module_name="Softmax2d",
        input_size=(1, 3, 10, 20),
        reference_fn=lambda i, *_: torch.exp(i).div(torch.exp(i).sum(1, False)),
    ),
    dict(
        module_name="LogSoftmax",
        constructor_args=(1,),
        cpp_constructor_args="torch::nn::LogSoftmaxOptions(1)",
        input_size=(10, 20),
        reference_fn=lambda i, *_: torch.exp(i)
        .div_(torch.exp(i).sum(1, True).expand(10, 20))
        .log_(),
    ),
    dict(
        module_name="LogSoftmax",
        constructor_args=(1,),
        cpp_constructor_args="torch::nn::LogSoftmaxOptions(1)",
        input_size=(1, 3, 10, 20),
        reference_fn=lambda i, *_: torch.exp(i).div_(torch.exp(i).sum(1, False)).log_(),
        desc="multiparam",
    ),
    dict(
        module_name="ELU",
        constructor_args=(2.0,),
        cpp_constructor_args="torch::nn::ELUOptions().alpha(2.)",
        input_size=(3, 2, 5),
        reference_fn=lambda x, *_: torch.where(x >= 0, x, 2 * (x.exp() - 1)),
    ),
    # TODO: reference function
    dict(
        module_name="Hardshrink",
        constructor_args=(2.0,),
        cpp_constructor_args="torch::nn::HardshrinkOptions(2.)",
        input_size=(4, 3, 2, 4),
    ),
    dict(module_name="LeakyReLU", input_size=(3, 2, 5), check_inplace=True),
    dict(
        module_name="LeakyReLU",
        constructor_args=(0.5,),
        cpp_constructor_args="torch::nn::LeakyReLUOptions().negative_slope(0.5)",
        input_size=(3, 2, 5),
        check_inplace=True,
        desc="with_negval",
    ),
    dict(
        module_name="LogSigmoid",
        input_size=(2, 3, 4),
        reference_fn=lambda i, *_: i.sigmoid().log(),
    ),
    dict(
        module_name="Softplus",
        input_size=(10, 20),
        reference_fn=lambda i, *_: torch.log(1 + torch.exp(i)),
    ),
    dict(
        module_name="Softplus",
        constructor_args=(2,),
        cpp_constructor_args="torch::nn::SoftplusOptions().beta(2)",
        input_size=(10, 20),
        reference_fn=lambda i, *_: 1.0 / 2.0 * torch.log(1 + torch.exp(2 * i)),
        desc="beta",
    ),
    dict(
        module_name="Softplus",
        constructor_args=(2, -100),
        cpp_constructor_args="torch::nn::SoftplusOptions().beta(2).threshold(-100)",
        input_size=(10, 20),
        reference_fn=(
            lambda i, *_: ((i * 2) > -100).type_as(i) * i
            + ((i * 2) <= -100).type_as(i) * 1.0 / 2.0 * torch.log(1 + torch.exp(2 * i))
        ),
        desc="beta_threshold",
    ),
    dict(
        module_name="Softshrink",
        input_size=(3, 2, 5),
    ),
    dict(
        module_name="Softshrink",
        constructor_args=(1,),
        cpp_constructor_args="torch::nn::SoftshrinkOptions(1)",
        input_size=(3, 2, 5),
        desc="lambda",
    ),
    dict(
        module_name="CrossMapLRN2d",
        constructor_args=(5, 5e-3, 1e-3, 2),
        cpp_constructor_args="torch::nn::CrossMapLRN2dOptions(5).alpha(5e-3).beta(1e-3).k(2)",
        input_size=(2, 3, 6, 6),
        check_gradgrad=False,
    ),
    dict(
        module_name="PReLU",
        input_size=(2, 3, 4),
        reference_fn=lambda i, p, _: torch.clamp(i, min=0)
        + torch.clamp(i, max=0) * p[0][0],
        desc="1d",
    ),
    dict(
        module_name="PReLU",
        constructor_args=(3,),
        cpp_constructor_args="torch::nn::PReLUOptions().num_parameters(3)",
        input_size=(2, 3, 4),
        desc="1d_multiparam",
        reference_fn=lambda i, p, _: torch.clamp(i, min=0)
        + torch.clamp(i, max=0) * p[0][0],
    ),
    dict(
        module_name="PReLU",
        input_size=(2, 3, 4, 5),
        desc="2d",
        reference_fn=lambda i, p, _: torch.clamp(i, min=0)
        + torch.clamp(i, max=0) * p[0][0],
    ),
    dict(
        module_name="PReLU",
        constructor_args=(3,),
        cpp_constructor_args="torch::nn::PReLUOptions().num_parameters(3)",
        input_size=(2, 3, 4, 5),
        desc="2d_multiparam",
        reference_fn=lambda i, p, _: torch.clamp(i, min=0)
        + torch.clamp(i, max=0) * p[0][0],
    ),
    dict(
        module_name="PReLU",
        input_size=(2, 3, 4, 5, 6),
        reference_fn=lambda i, p, _: torch.clamp(i, min=0)
        + torch.clamp(i, max=0) * p[0][0],
        desc="3d",
    ),
    dict(
        module_name="PReLU",
        constructor_args=(3,),
        cpp_constructor_args="torch::nn::PReLUOptions().num_parameters(3)",
        input_size=(2, 3, 4, 5, 6),
        desc="3d_multiparam",
        reference_fn=lambda i, p, _: torch.clamp(i, min=0)
        + torch.clamp(i, max=0) * p[0][0],
    ),
    dict(
        module_name="Softsign",
        input_size=(3, 2, 5),
        reference_fn=lambda i, *_: i.div(1 + torch.abs(i)),
    ),
    dict(
        module_name="Softmin",
        constructor_args=(1,),
        cpp_constructor_args="torch::nn::SoftminOptions(1)",
        input_size=(10, 20),
    ),
    dict(
        module_name="Softmin",
        constructor_args=(1,),
        cpp_constructor_args="torch::nn::SoftminOptions(1)",
        input_size=(2, 3, 5, 10),
        desc="multidim",
    ),
    dict(
        module_name="Tanhshrink",
        input_size=(2, 3, 4, 5),
    ),
]
404
405
406# Generates rand tensor with non-equal values. This ensures that duplicate
407# values won't be causing test failure for modules like MaxPooling.
408# size should be small, otherwise randperm fails / long overflows.
409def _rand_tensor_non_equal(*size):
410total = reduce(mul, size, 1)
411return torch.randperm(total).view(*size).double()
412
413
def wrap_functional(fn, **kwargs):
    """Wrap functional *fn* (with fixed keyword arguments) in an nn.Module class.

    Returns the class itself (not an instance); the test harness instantiates it.
    """
    captured = dict(kwargs)

    class FunctionalModule(nn.Module):
        def forward(self, *inputs):
            return fn(*inputs, **captured)

    return FunctionalModule
420
421
def poissonnllloss_no_reduce_test():
    """Parity test dict: F.poisson_nll_loss with reduction='none'."""
    t = torch.randn(10, 10)
    return dict(
        fullname="PoissonNLLLoss_no_reduce",
        constructor=wrap_functional(
            lambda i: F.poisson_nll_loss(i, t.type_as(i), reduction="none")
        ),
        cpp_function_call="F::poisson_nll_loss("
        "i, t.to(i.options()), F::PoissonNLLLossFuncOptions().reduction(torch::kNone))",
        input_fn=lambda: torch.rand(10, 10),
        cpp_var_map={"i": "_get_input()", "t": t},
        reference_fn=lambda i, *_: i.exp() - t.mul(i),
        pickle=False,
    )
436
437
def bceloss_no_reduce_test():
    """Parity test dict: F.binary_cross_entropy with reduction='none'."""
    t = Variable(torch.randn(15, 10).gt(0).double())
    return dict(
        fullname="BCELoss_no_reduce",
        constructor=wrap_functional(
            lambda i: F.binary_cross_entropy(i, t.type_as(i), reduction="none")
        ),
        cpp_function_call="F::binary_cross_entropy("
        "i, t.to(i.options()), F::BinaryCrossEntropyFuncOptions().reduction(torch::kNone))",
        # clamp keeps inputs strictly inside (0, 1) so log() in the reference is finite
        input_fn=lambda: torch.rand(15, 10).clamp_(2.8e-2, 1 - 2.8e-2),
        cpp_var_map={"i": "_get_input()", "t": t},
        reference_fn=lambda i, *_: -(t * i.log() + (1 - t) * (1 - i).log()),
        pickle=False,
        precision=7e-4,
    )
453
454
def bceloss_no_reduce_scalar_test():
    """Parity test dict: F.binary_cross_entropy on a 0-dim input, reduction='none'."""
    t = torch.randn(()).gt(0).double()
    return dict(
        fullname="BCELoss_no_reduce_scalar",
        constructor=wrap_functional(
            lambda i: F.binary_cross_entropy(i, t.type_as(i), reduction="none")
        ),
        cpp_function_call="F::binary_cross_entropy("
        "i, t.to(i.options()), F::BinaryCrossEntropyFuncOptions().reduction(torch::kNone))",
        # clamp keeps the input strictly inside (0, 1) so log() in the reference is finite
        input_fn=lambda: torch.rand(()).clamp_(2.8e-2, 1 - 2.8e-2),
        cpp_var_map={"i": "_get_input()", "t": t},
        reference_fn=lambda i, *_: -(t * i.log() + (1 - t) * (1 - i).log()),
        pickle=False,
    )
469
470
def bceloss_weights_no_reduce_test():
    """Parity test dict: F.binary_cross_entropy with per-class weights, reduction='none'."""
    t = Variable(torch.randn(15, 10).gt(0).double())
    weights = torch.rand(10)
    return dict(
        fullname="BCELoss_weights_no_reduce",
        constructor=wrap_functional(
            lambda i: F.binary_cross_entropy(
                i, t.type_as(i), weight=weights.type_as(i), reduction="none"
            )
        ),
        cpp_function_call="F::binary_cross_entropy("
        "i, t.to(i.options()), "
        "F::BinaryCrossEntropyFuncOptions().weight(weights.to(i.options())).reduction(torch::kNone))",
        input_fn=lambda: torch.rand(15, 10).clamp_(2.8e-2, 1 - 2.8e-2),
        cpp_var_map={"i": "_get_input()", "t": t, "weights": weights},
        reference_fn=lambda i, p, m: -(t * i.log() + (1 - t) * (1 - i).log()) * weights,
        pickle=False,
        precision=3e-4,
    )
490
491
def bceloss_weights_no_reduce_scalar_test():
    """Parity test dict: weighted F.binary_cross_entropy on a 0-dim input."""
    t = torch.randn(()).double()
    weights = torch.rand(())
    return dict(
        fullname="BCELoss_weights_no_reduce_scalar",
        constructor=wrap_functional(
            lambda i: F.binary_cross_entropy(
                i, t.type_as(i), weight=weights.type_as(i), reduction="none"
            )
        ),
        cpp_function_call="""F::binary_cross_entropy(
            i, t.to(i.options()),
            F::BinaryCrossEntropyFuncOptions().weight(weights.to(i.options())).reduction(torch::kNone))""",
        cpp_var_map={"i": "_get_input()", "t": t, "weights": weights},
        input_fn=lambda: torch.rand(()).clamp_(2.8e-2, 1 - 2.8e-2),
        reference_fn=lambda i, *_: -(t * i.log() + (1 - t) * (1 - i).log()) * weights,
        pickle=False,
    )
510
511
def bce_with_logistic_legacy_enum_test():
    """Parity test dict: binary_cross_entropy_with_logits via the legacy reduce=False API."""
    t = Variable(torch.randn(15, 10).gt(0).double())
    sigmoid = nn.Sigmoid()
    return dict(
        fullname="BCEWithLogitsLoss_legacy_enum",
        constructor=wrap_functional(
            # deliberately exercises the deprecated `reduce=False` spelling,
            # which maps to reduction='none'
            lambda i: F.binary_cross_entropy_with_logits(i, t.type_as(i), reduce=False)
        ),
        cpp_function_call="""F::binary_cross_entropy_with_logits(
            i, t.to(i.options()), F::BinaryCrossEntropyWithLogitsFuncOptions().reduction(torch::kNone))""",
        input_fn=lambda: torch.rand(15, 10).clamp_(2.8e-2, 1 - 2.8e-2),
        cpp_var_map={"i": "_get_input()", "t": t},
        reference_fn=lambda i, *_: -(
            t * sigmoid(i).log() + (1 - t) * (1 - sigmoid(i)).log()
        ),
        check_gradgrad=False,
        pickle=False,
    )
530
531
def bce_with_logistic_no_reduce_test():
    """Parity test dict: F.binary_cross_entropy_with_logits with reduction='none'."""
    t = Variable(torch.randn(15, 10).gt(0).double())
    sigmoid = nn.Sigmoid()
    return dict(
        fullname="BCEWithLogitsLoss_no_reduce",
        constructor=wrap_functional(
            lambda i: F.binary_cross_entropy_with_logits(
                i, t.type_as(i), reduction="none"
            )
        ),
        cpp_function_call="""F::binary_cross_entropy_with_logits(
            i, t.to(i.options()), F::BinaryCrossEntropyWithLogitsFuncOptions().reduction(torch::kNone))""",
        input_fn=lambda: torch.rand(15, 10).clamp_(2.8e-2, 1 - 2.8e-2),
        cpp_var_map={"i": "_get_input()", "t": t},
        reference_fn=lambda i, *_: -(
            t * sigmoid(i).log() + (1 - t) * (1 - sigmoid(i)).log()
        ),
        check_gradgrad=False,
        pickle=False,
    )
552
553
def bce_with_logistic_no_reduce_scalar_test():
    """Parity test dict: F.binary_cross_entropy_with_logits on a 0-dim input."""
    t = torch.randn(()).gt(0).double()
    sigmoid = nn.Sigmoid()
    return dict(
        fullname="BCEWithLogitsLoss_no_reduce_scalar",
        constructor=wrap_functional(
            lambda i: F.binary_cross_entropy_with_logits(
                i, t.type_as(i), reduction="none"
            )
        ),
        cpp_function_call="""F::binary_cross_entropy_with_logits(
            i, t.to(i.options()), F::BinaryCrossEntropyWithLogitsFuncOptions().reduction(torch::kNone))""",
        input_fn=lambda: torch.rand(()).clamp_(2.8e-2, 1 - 2.8e-2),
        cpp_var_map={"i": "_get_input()", "t": t},
        reference_fn=lambda i, *_: -(
            t * sigmoid(i).log() + (1 - t) * (1 - sigmoid(i)).log()
        ),
        check_gradgrad=False,
        pickle=False,
    )
574
575
def kldivloss_with_target_no_reduce_test():
    """Parity test dict: F.kl_div where the *target* is the varying input."""
    i = torch.rand(10, 10).log()

    return dict(
        fullname="KLDivLoss_with_target_no_reduce",
        constructor=wrap_functional(
            lambda t: F.kl_div(i.type_as(t), t, reduction="none")
        ),
        cpp_function_call="F::kl_div(i.to(t.options()), t, F::KLDivFuncOptions().reduction(torch::kNone))",
        input_fn=lambda: torch.rand(10, 10),
        # here `i` is the fixed tensor and `t` is bound to the harness input
        cpp_var_map={"i": i, "t": "_get_input()"},
        reference_fn=lambda t, *_: loss_reference_fns["KLDivLoss"](
            i.type_as(t), t, reduction="none"
        ),
        pickle=False,
    )
592
593
def kldivloss_no_reduce_test():
    """Parity test dict: F.kl_div with reduction='none'."""
    t = torch.randn(10, 10)
    return dict(
        fullname="KLDivLoss_no_reduce",
        constructor=wrap_functional(
            lambda i: F.kl_div(i, t.type_as(i), reduction="none")
        ),
        cpp_function_call="F::kl_div(i, t.to(i.options()), F::KLDivFuncOptions().reduction(torch::kNone))",
        # kl_div expects log-probabilities as input
        input_fn=lambda: torch.rand(10, 10).log(),
        cpp_var_map={"i": "_get_input()", "t": t},
        reference_fn=lambda i, *_: loss_reference_fns["KLDivLoss"](
            i, t.type_as(i), reduction="none"
        ),
        pickle=False,
    )
609
610
def kldivloss_no_reduce_scalar_test():
    """Parity test dict: F.kl_div on a 0-dim input, reduction='none'."""
    t = torch.randn(())
    return dict(
        fullname="KLDivLoss_no_reduce_scalar",
        constructor=wrap_functional(
            lambda i: F.kl_div(i, t.type_as(i), reduction="none")
        ),
        cpp_function_call="F::kl_div(i, t.to(i.options()), F::KLDivFuncOptions().reduction(torch::kNone))",
        # kl_div expects log-probabilities as input
        input_fn=lambda: torch.rand(()).log(),
        cpp_var_map={"i": "_get_input()", "t": t},
        reference_fn=lambda i, *_: loss_reference_fns["KLDivLoss"](
            i, t.type_as(i), reduction="none"
        ),
        pickle=False,
    )
626
627
def l1loss_no_reduce_test():
    """Parity test dict: F.l1_loss with reduction='none'."""
    t = torch.randn(2, 3, 4)
    return dict(
        fullname="L1Loss_no_reduce",
        constructor=wrap_functional(
            lambda i: F.l1_loss(i, t.type_as(i), reduction="none")
        ),
        cpp_function_call="F::l1_loss(i, t.to(i.options()), F::L1LossFuncOptions().reduction(torch::kNone))",
        input_fn=lambda: torch.randn(2, 3, 4),
        cpp_var_map={"i": "_get_input()", "t": t},
        reference_fn=lambda i, *_: (i - t.type_as(i)).abs(),
        pickle=False,
    )
641
642
def l1loss_no_reduce_scalar_test():
    """Parity test dict: F.l1_loss on a 0-dim input, reduction='none'."""
    t = torch.randn(())
    return dict(
        fullname="L1Loss_no_reduce_scalar",
        constructor=wrap_functional(
            lambda i: F.l1_loss(i, t.type_as(i), reduction="none")
        ),
        cpp_function_call="F::l1_loss(i, t.to(i.options()), F::L1LossFuncOptions().reduction(torch::kNone))",
        input_fn=lambda: torch.randn(()),
        cpp_var_map={"i": "_get_input()", "t": t},
        reference_fn=lambda i, *_: (i - t.type_as(i)).abs(),
        pickle=False,
    )
656
657
def mseloss_no_reduce_test():
    """Parity test dict: F.mse_loss with reduction='none'."""
    input_size = (2, 3, 4, 5)
    target = torch.randn(*input_size)
    return dict(
        fullname="MSELoss_no_reduce",
        constructor=wrap_functional(
            lambda i: F.mse_loss(i, target.type_as(i), reduction="none")
        ),
        cpp_function_call="F::mse_loss(i, target.to(i.options()), F::MSELossFuncOptions().reduction(torch::kNone))",
        input_size=input_size,
        cpp_var_map={"i": "_get_input()", "target": target},
        reference_fn=lambda i, *_: (i - target).pow(2),
        pickle=False,
    )
672
673
def mseloss_no_reduce_scalar_test():
    """Parity test dict: F.mse_loss on a 0-dim input, reduction='none'."""
    input_size = ()
    target = torch.randn(input_size)
    return dict(
        fullname="MSELoss_no_reduce_scalar",
        constructor=wrap_functional(
            lambda i: F.mse_loss(i, target.type_as(i), reduction="none")
        ),
        cpp_function_call="F::mse_loss(i, target.to(i.options()), F::MSELossFuncOptions().reduction(torch::kNone))",
        input_size=input_size,
        cpp_var_map={"i": "_get_input()", "target": target},
        reference_fn=lambda i, *_: (i - target).pow(2),
        pickle=False,
    )
688
689
def nllloss_no_reduce_test():
    """Parity test dict: F.nll_loss with reduction='none'."""
    t = Variable(torch.Tensor(15).uniform_().mul(10).floor().long())
    kwargs = {"reduction": "none"}
    return dict(
        fullname="NLLLoss_no_reduce",
        constructor=wrap_functional(
            lambda i: F.nll_loss(i, t.type_as(i).long(), **kwargs)
        ),
        cpp_function_call="""F::nll_loss(
            i, t.to(i.options()).to(torch::kLong), F::NLLLossFuncOptions().reduction(torch::kNone))""",
        # nll_loss expects log-probabilities as input
        input_fn=lambda: torch.rand(15, 10).log(),
        cpp_var_map={"i": "_get_input()", "t": t},
        reference_fn=lambda i, *_: loss_reference_fns["NLLLoss"](
            i, t.type_as(i).long(), **kwargs
        ),
        pickle=False,
    )
707
708
def nllloss_no_reduce_ignore_index_test():
    """Parity test dict: F.nll_loss with ignore_index=2, reduction='none'."""
    t = Variable(torch.Tensor(15).uniform_().mul(10).floor().long())
    kwargs = {"ignore_index": 2, "reduction": "none"}
    return dict(
        fullname="NLLLoss_no_reduce_ignore_index",
        constructor=wrap_functional(
            lambda i: F.nll_loss(i, t.type_as(i).long(), **kwargs)
        ),
        cpp_function_call="""F::nll_loss(
            i, t.to(i.options()).to(torch::kLong), F::NLLLossFuncOptions().ignore_index(2).reduction(torch::kNone))""",
        input_fn=lambda: torch.rand(15, 10).log(),
        cpp_var_map={"i": "_get_input()", "t": t},
        reference_fn=lambda i, *_: loss_reference_fns["NLLLoss"](
            i, t.type_as(i).long(), **kwargs
        ),
        pickle=False,
    )
726
727
def nllloss_no_reduce_weights_test():
    """Parity test dict: F.nll_loss with per-class weights, reduction='none'."""
    t = Variable(torch.Tensor(15).uniform_().mul(10).floor().long())
    weight = torch.rand(10)

    def kwargs(i):
        # weight must match the input's dtype/device, so it is built per input
        return {"weight": weight.type_as(i), "reduction": "none"}

    return dict(
        fullname="NLLLoss_no_reduce_weights",
        constructor=wrap_functional(
            lambda i: F.nll_loss(i, t.type_as(i).long(), **kwargs(i))
        ),
        cpp_function_call="""F::nll_loss(
            i, t.to(i.options()).to(torch::kLong),
            F::NLLLossFuncOptions().weight(weight.to(i.options())).reduction(torch::kNone))""",
        input_fn=lambda: torch.rand(15, 10).add(1e-2).log(),
        cpp_var_map={"i": "_get_input()", "t": t, "weight": weight},
        reference_fn=lambda i, *_: loss_reference_fns["NLLLoss"](
            i, t.type_as(i).long(), **kwargs(i)
        ),
        pickle=False,
    )
750
751
def nllloss_no_reduce_weights_ignore_index_test():
    """Parity test dict: weighted F.nll_loss with ignore_index=2, reduction='none'."""
    t = Variable(torch.Tensor(15).uniform_().mul(10).floor().long())
    weight = torch.rand(10)

    def kwargs(i):
        return {"weight": weight.type_as(i), "reduction": "none", "ignore_index": 2}

    return dict(
        fullname="NLLLoss_no_reduce_weights_ignore_index",
        constructor=wrap_functional(
            # NOTE(review): passes `i.data` here but `i` in reference_fn below;
            # looks inconsistent — confirm against the upstream test harness.
            lambda i: F.nll_loss(i, t.type_as(i).long(), **kwargs(i.data))
        ),
        cpp_function_call="""F::nll_loss(
            i, t.to(i.options()).to(torch::kLong),
            F::NLLLossFuncOptions().weight(weight.to(i.options())).reduction(torch::kNone).ignore_index(2))""",
        input_fn=lambda: torch.rand(15, 10).add(1e-2).log(),
        cpp_var_map={"i": "_get_input()", "t": t, "weight": weight},
        reference_fn=lambda i, *_: loss_reference_fns["NLLLoss"](
            i, t.type_as(i).long(), **kwargs(i)
        ),
        pickle=False,
    )
774
775
def nllloss_no_reduce_weights_ignore_index_neg_test():
    """Parity test dict: weighted F.nll_loss with negative ignore_index, reduction='none'."""
    t = Variable(torch.Tensor(15).uniform_().mul(10).floor().long())
    weight = torch.rand(10)

    def kwargs(i):
        return {"weight": weight.type_as(i), "reduction": "none", "ignore_index": -1}

    return dict(
        fullname="NLLLoss_no_reduce_weights_ignore_index_neg",
        constructor=wrap_functional(
            lambda i: F.nll_loss(i, t.type_as(i).long(), **kwargs(i))
        ),
        cpp_function_call="""F::nll_loss(
            i, t.to(i.options()).to(torch::kLong),
            F::NLLLossFuncOptions().weight(weight.to(i.options())).reduction(torch::kNone).ignore_index(-1))""",
        # NOTE(review): uses a fixed `input=` tensor instead of the `input_fn=lambda`
        # used by the sibling tests — presumably intentional; verify with the harness.
        input=torch.rand(15, 10).add(1e-2).log(),
        cpp_var_map={"i": "_get_input()", "t": t, "weight": weight},
        reference_fn=lambda i, *_: loss_reference_fns["NLLLoss"](
            i, t.type_as(i).long(), **kwargs(i)
        ),
        pickle=False,
    )
798
799
def nllloss2d_no_reduce_test():
    """Parity test dict: 2d F.nll_loss with reduction='none'."""
    t = Variable(torch.rand(2, 5, 5).mul(3).floor().long())
    kwargs = {"reduction": "none"}
    return dict(
        fullname="NLLLoss2d_no_reduce",
        constructor=wrap_functional(
            lambda i: F.nll_loss(i, t.type_as(i).long(), **kwargs)
        ),
        cpp_function_call="""F::nll_loss(
            i, t.to(i.options()).to(torch::kLong), F::NLLLossFuncOptions().reduction(torch::kNone))""",
        input_fn=lambda: torch.rand(2, 3, 5, 5).log(),
        cpp_var_map={"i": "_get_input()", "t": t},
        reference_fn=lambda i, *_: loss_reference_fns["NLLLossNd"](
            i, t.type_as(i).long(), **kwargs
        ),
        pickle=False,
    )
817
818
def nllloss2d_no_reduce_ignore_index_test():
    """Parity test dict: 2d F.nll_loss with ignore_index=1, reduction='none'."""
    t = Variable(torch.rand(2, 5, 5).mul(3).floor().long())
    kwargs = {"ignore_index": 1, "reduction": "none"}
    return dict(
        fullname="NLLLoss2d_no_reduce_ignore_index",
        constructor=wrap_functional(
            lambda i: F.nll_loss(i, t.type_as(i).long(), **kwargs)
        ),
        cpp_function_call="""F::nll_loss(
            i, t.to(i.options()).to(torch::kLong), F::NLLLossFuncOptions().ignore_index(1).reduction(torch::kNone))""",
        input_fn=lambda: torch.rand(2, 3, 5, 5).log(),
        cpp_var_map={"i": "_get_input()", "t": t},
        reference_fn=lambda i, *_: loss_reference_fns["NLLLossNd"](
            i, t.type_as(i).long(), **kwargs
        ),
        pickle=False,
    )
836
837
def nllloss2d_no_reduce_weights_test():
    """Parity test dict: 2d F.nll_loss with per-class weights, reduction='none'."""
    t = Variable(torch.rand(2, 5, 5).mul(3).floor().long())
    weight = torch.rand(3)

    def kwargs(i):
        return {"weight": weight.type_as(i), "reduction": "none"}

    return dict(
        fullname="NLLLoss2d_no_reduce_weights",
        constructor=wrap_functional(
            lambda i: F.nll_loss(i, t.type_as(i).long(), **kwargs(i))
        ),
        cpp_function_call="""F::nll_loss(
            i, t.to(i.options()).to(torch::kLong),
            F::NLLLossFuncOptions().weight(weight.to(i.options())).reduction(torch::kNone))""",
        input_fn=lambda: torch.rand(2, 3, 5, 5).log(),
        cpp_var_map={"i": "_get_input()", "t": t, "weight": weight},
        reference_fn=lambda i, *_: loss_reference_fns["NLLLossNd"](
            i, t.type_as(i).long(), **kwargs(i)
        ),
        pickle=False,
    )
860
861
def nlllossNd_no_reduce_test():
    """Parity test dict: N-d F.nll_loss with reduction='none'."""
    t = Variable(torch.rand(2, 5, 5, 2, 2).mul(3).floor().long())
    kwargs = {"reduction": "none"}
    return dict(
        fullname="NLLLossNd_no_reduce",
        constructor=wrap_functional(
            lambda i: F.nll_loss(i, t.type_as(i).long(), **kwargs)
        ),
        cpp_function_call="""F::nll_loss(
            i, t.to(i.options()).to(torch::kLong), F::NLLLossFuncOptions().reduction(torch::kNone))""",
        input_fn=lambda: torch.rand(2, 3, 5, 5, 2, 2).log(),
        cpp_var_map={"i": "_get_input()", "t": t},
        reference_fn=lambda i, *_: loss_reference_fns["NLLLossNd"](
            i, t.type_as(i).long(), **kwargs
        ),
        pickle=False,
    )
879
880
def nlllossNd_no_reduce_ignore_index_test():
    """Parity test dict: N-d F.nll_loss with ignore_index=1, reduction='none'."""
    t = Variable(torch.rand(2, 5, 5, 2, 2).mul(3).floor().long())
    kwargs = {"ignore_index": 1, "reduction": "none"}
    return dict(
        fullname="NLLLossNd_no_reduce_ignore_index",
        constructor=wrap_functional(
            lambda i: F.nll_loss(i, t.type_as(i).long(), **kwargs)
        ),
        cpp_function_call="""F::nll_loss(
            i, t.to(i.options()).to(torch::kLong), F::NLLLossFuncOptions().ignore_index(1).reduction(torch::kNone))""",
        input_fn=lambda: torch.rand(2, 3, 5, 5, 2, 2).log(),
        cpp_var_map={"i": "_get_input()", "t": t},
        reference_fn=lambda i, *_: loss_reference_fns["NLLLossNd"](
            i, t.type_as(i).long(), **kwargs
        ),
        pickle=False,
    )
898
899
def nlllossNd_no_reduce_weights_test():
    """Parity test dict: N-d F.nll_loss with per-class weights, reduction='none'."""
    t = Variable(torch.rand(2, 5, 5, 2, 2).mul(3).floor().long())
    weight = torch.rand(3)

    def kwargs(i):
        return {"weight": weight.type_as(i), "reduction": "none"}

    return dict(
        fullname="NLLLossNd_no_reduce_weights",
        constructor=wrap_functional(
            lambda i: F.nll_loss(i, t.type_as(i).long(), **kwargs(i))
        ),
        cpp_function_call="""F::nll_loss(
            i, t.to(i.options()).to(torch::kLong),
            F::NLLLossFuncOptions().weight(weight.to(i.options())).reduction(torch::kNone))""",
        input_fn=lambda: torch.rand(2, 3, 5, 5, 2, 2).log(),
        cpp_var_map={"i": "_get_input()", "t": t, "weight": weight},
        reference_fn=lambda i, *_: loss_reference_fns["NLLLossNd"](
            i, t.type_as(i).long(), **kwargs(i)
        ),
        pickle=False,
    )
922
923
def smoothl1loss_no_reduce_test():
    """Parity test dict: F.smooth_l1_loss with reduction='none'."""
    t = torch.randn(2, 3, 4)
    return dict(
        fullname="SmoothL1Loss_no_reduce",
        constructor=wrap_functional(
            lambda i: F.smooth_l1_loss(i, t.type_as(i), reduction="none")
        ),
        cpp_function_call="""F::smooth_l1_loss(
            i, t.to(i.options()), F::SmoothL1LossFuncOptions().reduction(torch::kNone))""",
        input_fn=lambda: torch.randn(2, 3, 4),
        cpp_var_map={"i": "_get_input()", "t": t},
        reference_fn=lambda i, *_: loss_reference_fns["SmoothL1Loss"](
            i, t.type_as(i), reduction="none"
        ),
        pickle=False,
    )
940
941
def smoothl1loss_no_reduce_scalar_test():
    """Test spec: element-wise SmoothL1Loss (reduction='none') on 0-d tensors."""
    target = torch.randn(())
    return {
        "fullname": "SmoothL1Loss_no_reduce_scalar",
        "constructor": wrap_functional(
            lambda i: F.smooth_l1_loss(i, target.type_as(i), reduction="none")
        ),
        "cpp_function_call": """F::smooth_l1_loss(
i, t.to(i.options()), F::SmoothL1LossFuncOptions().reduction(torch::kNone))""",
        "input_fn": lambda: torch.randn(()),
        "cpp_var_map": {"i": "_get_input()", "t": target},
        "reference_fn": lambda i, *_: loss_reference_fns["SmoothL1Loss"](
            i, target.type_as(i), reduction="none"
        ),
        "pickle": False,
    }
958
959
def multilabelmarginloss_0d_no_reduce_test():
    """Test spec: MultiLabelMarginLoss on 0-d input/target, reduction='none'."""
    target = torch.zeros(()).long()
    return {
        "fullname": "MultiLabelMarginLoss_0d_no_reduce",
        "constructor": wrap_functional(
            lambda i: F.multilabel_margin_loss(
                i, target.type_as(i).long(), reduction="none"
            )
        ),
        "cpp_function_call": """F::multilabel_margin_loss(
i, t.to(i.options()).to(torch::kLong), F::MultilabelMarginLossFuncOptions().reduction(torch::kNone))""",
        "input_fn": lambda: torch.randn(()),
        "cpp_var_map": {"i": "_get_input()", "t": target},
        "reference_fn": lambda i, *_: loss_reference_fns["MultiLabelMarginLoss"](
            i, target.data.type_as(i).long(), reduction="none"
        ),
        "check_sum_reduction": True,
        "check_gradgrad": False,
        "pickle": False,
    }
978
979
def multilabelmarginloss_1d_no_reduce_test():
    """Test spec: MultiLabelMarginLoss on a 1-d sample, reduction='none'."""
    target = Variable(torch.rand(10).mul(10).floor().long())
    return {
        "fullname": "MultiLabelMarginLoss_1d_no_reduce",
        "constructor": wrap_functional(
            lambda i: F.multilabel_margin_loss(
                i, target.type_as(i).long(), reduction="none"
            )
        ),
        "cpp_function_call": """F::multilabel_margin_loss(
i, t.to(i.options()).to(torch::kLong), F::MultilabelMarginLossFuncOptions().reduction(torch::kNone))""",
        "input_fn": lambda: torch.randn(10),
        "cpp_var_map": {"i": "_get_input()", "t": target},
        "reference_fn": lambda i, *_: loss_reference_fns["MultiLabelMarginLoss"](
            i, target.data.type_as(i).long(), reduction="none"
        ),
        "check_sum_reduction": True,
        "check_gradgrad": False,
        "pickle": False,
    }
998
999
def multilabelmarginloss_index_neg_test():
    """Test spec: MultiLabelMarginLoss where targets may contain -1 padding."""
    # Targets are clamped at -1, the sentinel that terminates each label list.
    target = Variable(
        torch.clamp(torch.rand(5, 10).add(-0.5).mul(20).floor().long(), min=-1)
    )
    return {
        "fullname": "MultiLabelMarginLoss_index_neg",
        "constructor": wrap_functional(
            lambda i: F.multilabel_margin_loss(
                i, target.type_as(i).long(), reduction="none"
            )
        ),
        "cpp_function_call": """F::multilabel_margin_loss(
i, t.to(i.options()).to(torch::kLong), F::MultilabelMarginLossFuncOptions().reduction(torch::kNone))""",
        "input_fn": lambda: torch.randn(5, 10),
        "cpp_var_map": {"i": "_get_input()", "t": target},
        "reference_fn": lambda i, *_: loss_reference_fns["MultiLabelMarginLoss"](
            i, target.data.type_as(i).long(), reduction="none"
        ),
        "check_sum_reduction": True,
        "check_gradgrad": False,
        "pickle": False,
    }
1020
1021
def multilabelmarginloss_no_reduce_test():
    """Test spec: batched MultiLabelMarginLoss with reduction='none'."""
    target = Variable(torch.rand(5, 10).mul(10).floor().long())
    return {
        "fullname": "MultiLabelMarginLoss_no_reduce",
        "constructor": wrap_functional(
            lambda i: F.multilabel_margin_loss(
                i, target.type_as(i).long(), reduction="none"
            )
        ),
        "cpp_function_call": """F::multilabel_margin_loss(
i, t.to(i.options()).to(torch::kLong), F::MultilabelMarginLossFuncOptions().reduction(torch::kNone))""",
        "input_fn": lambda: torch.randn(5, 10),
        "cpp_var_map": {"i": "_get_input()", "t": target},
        "reference_fn": lambda i, *_: loss_reference_fns["MultiLabelMarginLoss"](
            i, target.data.type_as(i).long(), reduction="none"
        ),
        "check_sum_reduction": True,
        "check_gradgrad": False,
        "pickle": False,
    }
1040
1041
def hingeembeddingloss_no_reduce_test():
    """Test spec: HingeEmbeddingLoss with random ±1 targets, reduction='none'."""
    # Map random sign bits {0, 1} onto the required {-1, +1} target values.
    target = Variable(torch.randn(10).gt(0).double().mul_(2).sub(1))
    return {
        "fullname": "HingeEmbeddingLoss_no_reduce",
        "constructor": wrap_functional(
            lambda i: F.hinge_embedding_loss(i, target.type_as(i), reduction="none")
        ),
        "cpp_function_call": """F::hinge_embedding_loss(
i, t.to(i.options()), F::HingeEmbeddingLossFuncOptions().reduction(torch::kNone))""",
        "input_fn": lambda: torch.randn(10),
        "cpp_var_map": {"i": "_get_input()", "t": target},
        "reference_fn": lambda i, *_: loss_reference_fns["HingeEmbeddingLoss"](
            i, target.type_as(i), reduction="none"
        ),
        "check_sum_reduction": True,
        "pickle": False,
    }
1059
1060
def hingeembeddingloss_margin_no_reduce_test():
    """Test spec: HingeEmbeddingLoss with a non-default margin, reduction='none'."""
    # Map random sign bits {0, 1} onto the required {-1, +1} target values.
    target = Variable(torch.randn(10).gt(0).double().mul_(2).sub(1))
    return {
        "fullname": "HingeEmbeddingLoss_margin_no_reduce",
        "constructor": wrap_functional(
            lambda i: F.hinge_embedding_loss(
                i, target.type_as(i), margin=0.5, reduction="none"
            )
        ),
        "cpp_function_call": """F::hinge_embedding_loss(
i, t.to(i.options()), F::HingeEmbeddingLossFuncOptions().margin(0.5).reduction(torch::kNone))""",
        "input_fn": lambda: torch.randn(10),
        "cpp_var_map": {"i": "_get_input()", "t": target},
        "reference_fn": lambda i, *_: loss_reference_fns["HingeEmbeddingLoss"](
            i, target.type_as(i), margin=0.5, reduction="none"
        ),
        "check_sum_reduction": True,
        "pickle": False,
    }
1080
1081
def softmarginloss_no_reduce_test():
    """Test spec: element-wise SoftMarginLoss (reduction='none')."""
    target = torch.randn(5, 5)
    return {
        "fullname": "SoftMarginLoss_no_reduce",
        "constructor": wrap_functional(
            lambda i: F.soft_margin_loss(i, target.type_as(i), reduction="none")
        ),
        "cpp_function_call": """F::soft_margin_loss(
i, t.to(i.options()), F::SoftMarginLossFuncOptions().reduction(torch::kNone))""",
        "input_fn": lambda: torch.randn(5, 5),
        "cpp_var_map": {"i": "_get_input()", "t": target},
        "reference_fn": lambda i, *_: loss_reference_fns["SoftMarginLoss"](
            i, target.type_as(i), reduction="none"
        ),
        "pickle": False,
    }
1098
1099
def multilabelsoftmarginloss_no_reduce_test():
    """Test spec: MultiLabelSoftMarginLoss with binary targets, reduction='none'."""
    target = torch.rand(5, 10).mul(2).floor()

    def _reference(i, *_):
        # Per-sample mean of the binary cross-entropy terms over classes.
        pos = target * i.sigmoid().log()
        neg = (1 - target) * (-i).sigmoid().log()
        return (-(pos + neg)).sum(dim=1) / i.size(1)

    return {
        "fullname": "MultiLabelSoftMarginLoss_no_reduce",
        "constructor": wrap_functional(
            lambda i: F.multilabel_soft_margin_loss(
                i, target.type_as(i), reduction="none"
            )
        ),
        "cpp_function_call": """F::multilabel_soft_margin_loss(
i, t.to(i.options()), F::MultilabelSoftMarginLossFuncOptions().reduction(torch::kNone))""",
        "input_fn": lambda: torch.randn(5, 10),
        "cpp_var_map": {"i": "_get_input()", "t": target},
        "reference_fn": _reference,
        "check_gradgrad": False,
        "pickle": False,
    }
1118
1119
def multilabelsoftmarginloss_weights_no_reduce_test():
    """Test spec: MultiLabelSoftMarginLoss with class weights, reduction='none'."""
    target = torch.rand(5, 10).mul(2).floor()
    class_weights = torch.rand(10)

    def _reference(i, *_):
        # Weighted per-sample mean of the binary cross-entropy terms.
        pos = target * i.sigmoid().log()
        neg = (1 - target) * (-i).sigmoid().log()
        return (-(pos + neg) * class_weights).sum(dim=1) / i.size(1)

    return {
        "fullname": "MultiLabelSoftMarginLoss_weights_no_reduce",
        "constructor": wrap_functional(
            lambda i: F.multilabel_soft_margin_loss(
                i, target.type_as(i), weight=class_weights.type_as(i), reduction="none"
            )
        ),
        "cpp_function_call": """F::multilabel_soft_margin_loss(
i, t.to(i.options()),
F::MultilabelSoftMarginLossFuncOptions().weight(weights.to(i.options())).reduction(torch::kNone))""",
        "input_fn": lambda: torch.randn(5, 10),
        "cpp_var_map": {"i": "_get_input()", "t": target, "weights": class_weights},
        "reference_fn": _reference,
        "check_sum_reduction": True,
        "check_gradgrad": False,
        "pickle": False,
    }
1143
1144
def multimarginloss_no_reduce_test():
    """Test spec: batched MultiMarginLoss with reduction='none'."""
    target = torch.rand(5).mul(8).floor().long()
    return {
        "fullname": "MultiMarginLoss_no_reduce",
        "constructor": wrap_functional(
            lambda i: F.multi_margin_loss(i, target.type_as(i).long(), reduction="none")
        ),
        "cpp_function_call": """F::multi_margin_loss(
i, t.to(i.options()).to(torch::kLong), F::MultiMarginLossFuncOptions().reduction(torch::kNone))""",
        "input_fn": lambda: torch.randn(5, 10),
        "cpp_var_map": {"i": "_get_input()", "t": target},
        "reference_fn": lambda i, *_: loss_reference_fns["MultiMarginLoss"](
            i, target.data.type_as(i).long(), reduction="none"
        ),
        "check_sum_reduction": True,
        "check_gradgrad": False,
        "pickle": False,
    }
1163
1164
def multimarginloss_1d_no_reduce_test():
    """Test spec: MultiMarginLoss on a single 1-d sample, reduction='none'."""
    target = torch.rand(1).mul(8).floor().long()
    return {
        "fullname": "MultiMarginLoss_1d_no_reduce",
        "constructor": wrap_functional(
            lambda i: F.multi_margin_loss(i, target.type_as(i).long(), reduction="none")
        ),
        "cpp_function_call": """F::multi_margin_loss(
i, t.to(i.options()).to(torch::kLong), F::MultiMarginLossFuncOptions().reduction(torch::kNone))""",
        "input_fn": lambda: torch.randn(10),
        "cpp_var_map": {"i": "_get_input()", "t": target},
        "reference_fn": lambda i, *_: loss_reference_fns["MultiMarginLoss"](
            i, target.data.type_as(i).long(), reduction="none"
        ),
        "check_sum_reduction": True,
        "check_gradgrad": False,
        "pickle": False,
    }
1183
1184
def multimarginloss_1d_input_0d_target_no_reduce_test():
    """Test spec: MultiMarginLoss, 1-d input with a 0-d target, reduction='none'."""
    # NOTE(review): fullname is lower-cased unlike its siblings; kept verbatim
    # because it is a runtime test identifier.
    target = torch.rand(()).mul(8).floor().long()
    return {
        "fullname": "multimarginloss_1d_input_0d_target_no_reduce",
        "constructor": wrap_functional(
            lambda i: F.multi_margin_loss(i, target.type_as(i).long(), reduction="none")
        ),
        "cpp_function_call": """F::multi_margin_loss(
i, t.to(i.options()).to(torch::kLong), F::MultiMarginLossFuncOptions().reduction(torch::kNone))""",
        "input_fn": lambda: torch.randn(10),
        "cpp_var_map": {"i": "_get_input()", "t": target},
        "reference_fn": lambda i, *_: loss_reference_fns["MultiMarginLoss"](
            i, target.data.type_as(i).long(), reduction="none"
        ),
        "check_sum_reduction": True,
        "check_gradgrad": False,
        "pickle": False,
    }
1203
1204
def multimarginloss_p_no_reduce_test():
    """Test spec: MultiMarginLoss with p=2 margin power, reduction='none'."""
    target = torch.rand(5).mul(8).floor().long()
    return {
        "fullname": "MultiMarginLoss_p_no_reduce",
        "constructor": wrap_functional(
            lambda i: F.multi_margin_loss(
                i, target.type_as(i).long(), p=2, reduction="none"
            )
        ),
        "cpp_function_call": """F::multi_margin_loss(
i, t.to(i.options()).to(torch::kLong), F::MultiMarginLossFuncOptions().p(2).reduction(torch::kNone))""",
        # Clamp keeps inputs away from 0/1 so the squared hinge stays well-behaved.
        "input_fn": lambda: torch.randn(5, 10).clamp_(1e-2, 1 - 1e-2),
        "cpp_var_map": {"i": "_get_input()", "t": target},
        "reference_fn": lambda i, *_: loss_reference_fns["MultiMarginLoss"](
            i, target.data.type_as(i).long(), p=2, reduction="none"
        ),
        "check_sum_reduction": True,
        "check_gradgrad": False,
        "pickle": False,
    }
1223
1224
def multimarginloss_margin_no_reduce_test():
    """Test spec: MultiMarginLoss with a non-default margin, reduction='none'."""
    target = torch.rand(5).mul(8).floor().long()
    return {
        "fullname": "MultiMarginLoss_margin_no_reduce",
        "constructor": wrap_functional(
            lambda i: F.multi_margin_loss(
                i, target.type_as(i).long(), margin=0.5, reduction="none"
            )
        ),
        "cpp_function_call": """F::multi_margin_loss(
i, t.to(i.options()).to(torch::kLong),
F::MultiMarginLossFuncOptions().margin(0.5).reduction(torch::kNone))""",
        "input_fn": lambda: torch.randn(5, 10),
        "cpp_var_map": {"i": "_get_input()", "t": target},
        "reference_fn": lambda i, *_: loss_reference_fns["MultiMarginLoss"](
            i, target.data.type_as(i).long(), margin=0.5, reduction="none"
        ),
        "check_sum_reduction": True,
        "check_gradgrad": False,
        "pickle": False,
    }
1246
1247
def multimarginloss_weights_no_reduce_test():
    """Test spec: MultiMarginLoss with per-class weights, reduction='none'."""
    target = torch.rand(5).mul(8).floor().long()
    class_weights = torch.rand(10)
    return {
        "fullname": "MultiMarginLoss_weights_no_reduce",
        "constructor": wrap_functional(
            lambda i: F.multi_margin_loss(
                i,
                target.type_as(i).long(),
                weight=class_weights.type_as(i),
                reduction="none",
            )
        ),
        "cpp_function_call": """F::multi_margin_loss(
i, t.to(i.options()).to(torch::kLong),
F::MultiMarginLossFuncOptions().weight(weights.to(i.options())).reduction(torch::kNone))""",
        "input_fn": lambda: torch.randn(5, 10),
        "cpp_var_map": {"i": "_get_input()", "t": target, "weights": class_weights},
        "reference_fn": lambda i, *_: loss_reference_fns["MultiMarginLoss"](
            i, target.data.type_as(i).long(), weight=class_weights, reduction="none"
        ),
        "check_sum_reduction": True,
        "check_gradgrad": False,
        "pickle": False,
    }
1270
1271
def fractional_max_pool2d_test(test_case):
    """Build a module-test spec for ``nn.FractionalMaxPool2d``.

    Args:
        test_case: Which variant to build — ``"ratio"`` (output_ratio=0.5)
            or ``"size"`` (explicit output_size).

    Returns:
        dict describing the test configuration (constructor, C++ mirror
        args, input size, and fullname).

    Raises:
        ValueError: if ``test_case`` is not a known variant. (Previously an
            unknown value silently returned ``None``, deferring the failure
            to the call site.)
    """
    # Fixed samples make the fractional pooling deterministic, so the
    # Python and C++ modules produce identical outputs.
    random_samples = torch.DoubleTensor(1, 3, 2).uniform_()
    if test_case == "ratio":
        return dict(
            constructor=lambda: nn.FractionalMaxPool2d(
                2, output_ratio=0.5, _random_samples=random_samples
            ),
            cpp_constructor_args="""torch::nn::FractionalMaxPool2dOptions(2)
.output_ratio(0.5)
._random_samples(random_samples)""",
            input_size=(1, 3, 5, 7),
            cpp_var_map={"random_samples": random_samples},
            fullname="FractionalMaxPool2d_ratio",
        )
    elif test_case == "size":
        return dict(
            constructor=lambda: nn.FractionalMaxPool2d(
                (2, 3), output_size=(4, 3), _random_samples=random_samples
            ),
            cpp_constructor_args="""torch::nn::FractionalMaxPool2dOptions({2, 3})
.output_size(std::vector<int64_t>({4, 3}))
._random_samples(random_samples)""",
            input_size=(1, 3, 7, 6),
            cpp_var_map={"random_samples": random_samples},
            fullname="FractionalMaxPool2d_size",
        )
    raise ValueError(f"unknown fractional_max_pool2d test case: {test_case!r}")
1298
1299
def fractional_max_pool3d_test(test_case):
    """Build a module-test spec for ``nn.FractionalMaxPool3d``.

    Args:
        test_case: Which variant to build — ``"ratio"`` (output_ratio=0.5),
            ``"size"`` (cubic output_size), or ``"asymsize"`` (asymmetric
            kernel and output sizes).

    Returns:
        dict describing the test configuration (constructor, C++ mirror
        args, input size, and fullname).

    Raises:
        ValueError: if ``test_case`` is not a known variant. (Previously an
            unknown value silently returned ``None``, deferring the failure
            to the call site.)
    """
    # Fixed samples make the fractional pooling deterministic, so the
    # Python and C++ modules produce identical outputs.
    random_samples = torch.DoubleTensor(2, 4, 3).uniform_()
    if test_case == "ratio":
        return dict(
            constructor=lambda: nn.FractionalMaxPool3d(
                2, output_ratio=0.5, _random_samples=random_samples
            ),
            cpp_constructor_args="""torch::nn::FractionalMaxPool3dOptions(2)
.output_ratio(0.5)
._random_samples(random_samples)""",
            input_size=(2, 4, 5, 5, 5),
            cpp_var_map={"random_samples": random_samples},
            fullname="FractionalMaxPool3d_ratio",
        )
    elif test_case == "size":
        return dict(
            constructor=lambda: nn.FractionalMaxPool3d(
                (2, 2, 2), output_size=(4, 4, 4), _random_samples=random_samples
            ),
            cpp_constructor_args="""torch::nn::FractionalMaxPool3dOptions({2, 2, 2})
.output_size(std::vector<int64_t>({4, 4, 4}))
._random_samples(random_samples)""",
            input_size=(2, 4, 7, 7, 7),
            cpp_var_map={"random_samples": random_samples},
            fullname="FractionalMaxPool3d_size",
        )
    elif test_case == "asymsize":
        return dict(
            constructor=lambda: nn.FractionalMaxPool3d(
                (4, 2, 3), output_size=(10, 3, 2), _random_samples=random_samples
            ),
            cpp_constructor_args="""torch::nn::FractionalMaxPool3dOptions({4, 2, 3})
.output_size(std::vector<int64_t>({10, 3, 2}))
._random_samples(random_samples)""",
            input_size=(2, 4, 16, 7, 5),
            cpp_var_map={"random_samples": random_samples},
            fullname="FractionalMaxPool3d_asymsize",
        )
    raise ValueError(f"unknown fractional_max_pool3d test case: {test_case!r}")
1338
1339
1340new_module_tests = [
1341poissonnllloss_no_reduce_test(),
1342bceloss_no_reduce_test(),
1343bceloss_weights_no_reduce_test(),
1344bce_with_logistic_legacy_enum_test(),
1345bce_with_logistic_no_reduce_test(),
1346bceloss_no_reduce_scalar_test(),
1347bceloss_weights_no_reduce_scalar_test(),
1348bce_with_logistic_no_reduce_scalar_test(),
1349kldivloss_with_target_no_reduce_test(),
1350kldivloss_no_reduce_test(),
1351kldivloss_no_reduce_scalar_test(),
1352l1loss_no_reduce_test(),
1353l1loss_no_reduce_scalar_test(),
1354mseloss_no_reduce_test(),
1355mseloss_no_reduce_scalar_test(),
1356nllloss_no_reduce_test(),
1357nllloss_no_reduce_ignore_index_test(),
1358nllloss_no_reduce_weights_test(),
1359nllloss_no_reduce_weights_ignore_index_test(),
1360nllloss_no_reduce_weights_ignore_index_neg_test(),
1361nllloss2d_no_reduce_test(),
1362nllloss2d_no_reduce_weights_test(),
1363nllloss2d_no_reduce_ignore_index_test(),
1364nlllossNd_no_reduce_test(),
1365nlllossNd_no_reduce_weights_test(),
1366nlllossNd_no_reduce_ignore_index_test(),
1367smoothl1loss_no_reduce_test(),
1368smoothl1loss_no_reduce_scalar_test(),
1369multilabelmarginloss_0d_no_reduce_test(),
1370multilabelmarginloss_1d_no_reduce_test(),
1371multilabelmarginloss_index_neg_test(),
1372multilabelmarginloss_no_reduce_test(),
1373hingeembeddingloss_no_reduce_test(),
1374hingeembeddingloss_margin_no_reduce_test(),
1375softmarginloss_no_reduce_test(),
1376multilabelsoftmarginloss_no_reduce_test(),
1377multilabelsoftmarginloss_weights_no_reduce_test(),
1378multimarginloss_no_reduce_test(),
1379multimarginloss_1d_no_reduce_test(),
1380multimarginloss_1d_input_0d_target_no_reduce_test(),
1381multimarginloss_p_no_reduce_test(),
1382multimarginloss_margin_no_reduce_test(),
1383multimarginloss_weights_no_reduce_test(),
1384fractional_max_pool2d_test("ratio"),
1385fractional_max_pool2d_test("size"),
1386fractional_max_pool3d_test("ratio"),
1387fractional_max_pool3d_test("size"),
1388fractional_max_pool3d_test("asymsize"),
1389dict(
1390module_name="BatchNorm1d",
1391constructor_args=(10,),
1392cpp_constructor_args="torch::nn::BatchNorm1dOptions(10)",
1393input_size=(4, 10),
1394cudnn=True,
1395check_eval=True,
1396desc="affine",
1397test_cuda=(not TEST_WITH_ROCM),
1398pickle=False,
1399),
1400dict(
1401module_name="BatchNorm1d",
1402constructor_args=(5,),
1403cpp_constructor_args="torch::nn::BatchNorm1dOptions(5)",
1404input_size=(4, 5, 3),
1405cudnn=True,
1406check_eval=True,
1407desc="3d_input",
1408pickle=False,
1409),
1410dict(
1411module_name="BatchNorm1d",
1412constructor_args=(10, 1e-3, None),
1413cpp_constructor_args="torch::nn::BatchNorm1dOptions(10).eps(1e-3).momentum(c10::nullopt)",
1414input_size=(4, 10),
1415cudnn=True,
1416check_eval=True,
1417desc="affine_simple_average",
1418test_cuda=(not TEST_WITH_ROCM),
1419pickle=False,
1420),
1421dict(
1422module_name="BatchNorm1d",
1423constructor_args=(10, 1e-3, 0.3, False),
1424cpp_constructor_args="torch::nn::BatchNorm1dOptions(10).eps(1e-3).momentum(0.3).affine(false)",
1425input_size=(4, 10),
1426cudnn=True,
1427check_eval=True,
1428desc="not_affine",
1429pickle=False,
1430),
1431dict(
1432module_name="BatchNorm1d",
1433constructor_args=(10, 1e-3, 0.3, True, False),
1434cpp_constructor_args="""torch::nn::BatchNorm1dOptions(10)
1435.eps(1e-3).momentum(0.3).affine(true).track_running_stats(false)""",
1436input_size=(4, 10),
1437cudnn=True,
1438check_eval=True,
1439desc="not_tracking_stats",
1440test_cuda=(not TEST_WITH_ROCM),
1441pickle=False,
1442),
1443dict(
1444module_name="BatchNorm1d",
1445constructor_args=(5, 1e-3, 0.3, False),
1446cpp_constructor_args="torch::nn::BatchNorm1dOptions(5).eps(1e-3).momentum(0.3).affine(false)",
1447input_size=(4, 5, 3),
1448cudnn=True,
1449check_eval=True,
1450desc="3d_input_not_affine",
1451pickle=False,
1452),
1453dict(
1454module_name="BatchNorm1d",
1455constructor_args=(5, 1e-3, 0.3, False),
1456cpp_constructor_args="torch::nn::BatchNorm1dOptions(5).eps(1e-3).momentum(0.3).affine(false)",
1457input_size=(0, 5, 9),
1458cudnn=True,
1459check_eval=True,
1460desc="zero_batch",
1461pickle=False,
1462),
1463dict(
1464module_name="BatchNorm2d",
1465constructor_args=(3,),
1466cpp_constructor_args="torch::nn::BatchNorm2dOptions(3)",
1467input_size=(2, 3, 6, 6),
1468cudnn=True,
1469check_eval=True,
1470pickle=False,
1471),
1472dict(
1473module_name="BatchNorm2d",
1474constructor_args=(3, 1e-3, None),
1475cpp_constructor_args="torch::nn::BatchNorm2dOptions(3).eps(1e-3).momentum(c10::nullopt)",
1476input_size=(2, 3, 6, 6),
1477cudnn=True,
1478check_eval=True,
1479desc="2d_simple_average",
1480pickle=False,
1481),
1482dict(
1483module_name="BatchNorm2d",
1484constructor_args=(3, 1e-3, 0.8),
1485cpp_constructor_args="torch::nn::BatchNorm2dOptions(3).eps(1e-3).momentum(0.8)",
1486input_size=(2, 3, 6, 6),
1487cudnn=True,
1488check_eval=True,
1489desc="momentum",
1490pickle=False,
1491),
1492dict(
1493module_name="BatchNorm2d",
1494constructor_args=(3, 1e-3, 0.8, False),
1495cpp_constructor_args="torch::nn::BatchNorm2dOptions(3).eps(1e-3).momentum(0.8).affine(false)",
1496input_size=(2, 3, 6, 6),
1497cudnn=True,
1498check_eval=True,
1499desc="not_affine",
1500pickle=False,
1501),
1502dict(
1503module_name="BatchNorm2d",
1504constructor_args=(3, 1e-3, 0.8, True, False),
1505cpp_constructor_args="""torch::nn::BatchNorm2dOptions(3)
1506.eps(1e-3).momentum(0.8).affine(true).track_running_stats(false)""",
1507input_size=(2, 3, 6, 6),
1508cudnn=True,
1509check_eval=True,
1510desc="not_tracking_stats",
1511pickle=False,
1512),
1513dict(
1514module_name="BatchNorm2d",
1515constructor_args=(5, 1e-3, 0.3, False),
1516cpp_constructor_args="torch::nn::BatchNorm2dOptions(5).eps(1e-3).momentum(0.3).affine(false)",
1517input_size=(0, 5, 2, 2),
1518cudnn=True,
1519check_eval=True,
1520desc="zero_batch",
1521pickle=False,
1522),
1523dict(
1524module_name="BatchNorm3d",
1525constructor_args=(3,),
1526cpp_constructor_args="torch::nn::BatchNorm3dOptions(3)",
1527input_size=(2, 3, 4, 4, 4),
1528cudnn=True,
1529check_eval=True,
1530pickle=False,
1531),
1532dict(
1533module_name="BatchNorm3d",
1534constructor_args=(3, 1e-3, None),
1535cpp_constructor_args="torch::nn::BatchNorm3dOptions(3).eps(1e-3).momentum(c10::nullopt)",
1536input_size=(2, 3, 4, 4, 4),
1537cudnn=True,
1538check_eval=True,
1539desc="3d_simple_average",
1540pickle=False,
1541),
1542dict(
1543module_name="BatchNorm3d",
1544constructor_args=(3, 1e-3, 0.7),
1545cpp_constructor_args="torch::nn::BatchNorm3dOptions(3).eps(1e-3).momentum(0.7)",
1546input_size=(2, 3, 4, 4, 4),
1547cudnn=True,
1548check_eval=True,
1549desc="momentum",
1550pickle=False,
1551),
1552dict(
1553module_name="BatchNorm3d",
1554constructor_args=(3, 1e-3, 0.7, False),
1555cpp_constructor_args="torch::nn::BatchNorm3dOptions(3).eps(1e-3).momentum(0.7).affine(false)",
1556input_size=(2, 3, 4, 4, 4),
1557cudnn=True,
1558check_eval=True,
1559desc="not_affine",
1560pickle=False,
1561),
1562dict(
1563module_name="BatchNorm3d",
1564constructor_args=(3, 1e-3, 0.7, True, False),
1565cpp_constructor_args="""torch::nn::BatchNorm3dOptions(3)
1566.eps(1e-3).momentum(0.7).affine(true).track_running_stats(false)""",
1567input_size=(2, 3, 4, 4, 4),
1568cudnn=True,
1569check_eval=True,
1570desc="not_tracking_stats",
1571pickle=False,
1572),
1573dict(
1574module_name="BatchNorm3d",
1575constructor_args=(5, 1e-3, 0.3, False),
1576cpp_constructor_args="torch::nn::BatchNorm3dOptions(5).eps(1e-3).momentum(0.3).affine(false)",
1577input_size=(0, 5, 2, 2, 2),
1578cudnn=True,
1579check_eval=True,
1580desc="zero_batch",
1581pickle=False,
1582),
1583dict(
1584module_name="InstanceNorm1d",
1585constructor_args=(3, 1e-3, 0.3),
1586cpp_constructor_args="torch::nn::InstanceNorm1dOptions(3).eps(1e-3).momentum(0.3)",
1587input_size=(4, 3, 15),
1588cudnn=True,
1589check_eval=True,
1590pickle=False,
1591),
1592dict(
1593module_name="InstanceNorm1d",
1594constructor_args=(3, 1e-3, 0.3, False, True),
1595cpp_constructor_args="""torch::nn::InstanceNorm1dOptions(3)
1596.eps(1e-3).momentum(0.3).affine(false).track_running_stats(true)""",
1597input_size=(4, 3, 15),
1598cudnn=True,
1599check_eval=True,
1600desc="tracking_stats",
1601pickle=False,
1602),
1603dict(
1604module_name="InstanceNorm2d",
1605constructor_args=(3, 1e-3, 0.3),
1606cpp_constructor_args="torch::nn::InstanceNorm2dOptions(3).eps(1e-3).momentum(0.3)",
1607input_size=(2, 3, 6, 6),
1608cudnn=True,
1609check_eval=True,
1610pickle=False,
1611),
1612dict(
1613module_name="InstanceNorm2d",
1614constructor_args=(3, 1e-3, 0.3, False, True),
1615cpp_constructor_args="""torch::nn::InstanceNorm2dOptions(3)
1616.eps(1e-3).momentum(0.3).affine(false).track_running_stats(true)""",
1617input_size=(2, 3, 6, 6),
1618cudnn=True,
1619check_eval=True,
1620desc="tracking_stats",
1621pickle=False,
1622),
1623dict(
1624module_name="InstanceNorm3d",
1625constructor_args=(3, 1e-3, 0.3),
1626cpp_constructor_args="torch::nn::InstanceNorm3dOptions(3).eps(1e-3).momentum(0.3)",
1627input_size=(2, 3, 4, 4, 4),
1628cudnn=True,
1629check_eval=True,
1630pickle=False,
1631),
1632dict(
1633module_name="InstanceNorm3d",
1634constructor_args=(3, 1e-3, 0.3, False, True),
1635cpp_constructor_args="""torch::nn::InstanceNorm3dOptions(3)
1636.eps(1e-3).momentum(0.3).affine(false).track_running_stats(true)""",
1637input_size=(2, 3, 4, 4, 4),
1638cudnn=True,
1639check_eval=True,
1640desc="tracking_stats",
1641pickle=False,
1642),
1643dict(
1644module_name="LayerNorm",
1645constructor_args=([5], 1e-3),
1646cpp_constructor_args="torch::nn::LayerNormOptions({5}).eps(1e-3)",
1647input_size=(4, 5, 5),
1648cudnn=True,
1649check_eval=True,
1650desc="1d_elementwise_affine",
1651),
1652dict(
1653module_name="LayerNorm",
1654constructor_args=([5], 1e-3, False),
1655cpp_constructor_args="torch::nn::LayerNormOptions({5}).eps(1e-3).elementwise_affine(false)",
1656input_size=(4, 5, 5),
1657cudnn=True,
1658check_eval=True,
1659desc="1d_no_elementwise_affine",
1660),
1661dict(
1662module_name="LayerNorm",
1663constructor_args=([2, 2, 5], 1e-3),
1664cpp_constructor_args="torch::nn::LayerNormOptions({2, 2, 5}).eps(1e-3)",
1665input_size=(4, 2, 2, 5),
1666cudnn=True,
1667check_eval=True,
1668desc="3d_elementwise_affine",
1669),
1670dict(
1671module_name="LayerNorm",
1672constructor_args=([2, 2, 5], 1e-3, False),
1673cpp_constructor_args="torch::nn::LayerNormOptions({2, 2, 5}).eps(1e-3).elementwise_affine(false)",
1674input_size=(4, 2, 2, 5),
1675cudnn=True,
1676check_eval=True,
1677desc="3d_no_elementwise_affine",
1678),
1679dict(
1680module_name="LayerNorm",
1681constructor_args=([5], 1e-3),
1682cpp_constructor_args="torch::nn::LayerNormOptions({5}).eps(1e-3)",
1683input_size=(0, 5),
1684cudnn=True,
1685check_eval=True,
1686desc="1d_empty_elementwise_affine",
1687),
1688dict(
1689module_name="GroupNorm",
1690constructor_args=(3, 6, 1e-3),
1691cpp_constructor_args="torch::nn::GroupNormOptions(3, 6).eps(1e-3)",
1692input_size=(4, 6, 5),
1693cudnn=True,
1694check_eval=True,
1695desc="1d_affine",
1696),
1697dict(
1698module_name="GroupNorm",
1699constructor_args=(5, 5, 1e-3, False),
1700cpp_constructor_args="torch::nn::GroupNormOptions(5, 5).eps(1e-3).affine(false)",
1701input_size=(4, 5, 5),
1702cudnn=True,
1703check_eval=True,
1704desc="1d_no_affine_IN", # this setting is equivalent with InstanceNormi
1705),
1706dict(
1707module_name="GroupNorm",
1708constructor_args=(1, 5, 1e-3, False),
1709cpp_constructor_args="torch::nn::GroupNormOptions(1, 5).eps(1e-3).affine(false)",
1710input_size=(4, 5, 5),
1711cudnn=True,
1712check_eval=True,
1713desc="1d_no_affine_LN", # this setting is equivalent with LayerNorm
1714),
1715dict(
1716module_name="GroupNorm",
1717constructor_args=(3, 6, 1e-3),
1718cpp_constructor_args="torch::nn::GroupNormOptions(3, 6).eps(1e-3)",
1719input_size=(4, 6, 2, 3),
1720cudnn=True,
1721check_eval=True,
1722desc="2d_affine",
1723),
1724dict(
1725module_name="GroupNorm",
1726constructor_args=(3, 3, 1e-3, False),
1727cpp_constructor_args="torch::nn::GroupNormOptions(3, 3).eps(1e-3).affine(false)",
1728input_size=(4, 3, 2, 3),
1729cudnn=True,
1730check_eval=True,
1731desc="2d_no_affine_IN", # this setting is equivalent with InstanceNorm
1732),
1733dict(
1734module_name="GroupNorm",
1735constructor_args=(1, 3, 1e-3, False),
1736cpp_constructor_args="torch::nn::GroupNormOptions(1, 3).eps(1e-3).affine(false)",
1737input_size=(4, 3, 2, 3),
1738cudnn=True,
1739check_eval=True,
1740desc="2d_no_affine_LN", # this setting is equivalent with LayerNorm
1741),
1742dict(
1743module_name="Conv1d",
1744constructor_args=(4, 5, 3),
1745cpp_constructor_args="torch::nn::Conv1dOptions(4, 5, 3)",
1746input_size=(2, 4, 10),
1747cudnn=True,
1748),
1749dict(
1750module_name="Conv1d",
1751constructor_args=(4, 5, 3, 2),
1752cpp_constructor_args="torch::nn::Conv1dOptions(4, 5, 3).stride(2)",
1753input_size=(2, 4, 10),
1754cudnn=True,
1755desc="stride",
1756),
1757dict(
1758module_name="Conv1d",
1759constructor_args=(4, 5, 3, 1, 1),
1760cpp_constructor_args="torch::nn::Conv1dOptions(4, 5, 3).stride(1).padding(1)",
1761input_size=(2, 4, 10),
1762cudnn=True,
1763desc="pad1",
1764),
1765dict(
1766module_name="Conv1d",
1767constructor_args=(4, 5, 5, 1, 2),
1768cpp_constructor_args="torch::nn::Conv1dOptions(4, 5, 5).stride(1).padding(2)",
1769input_size=(2, 4, 10),
1770cudnn=True,
1771desc="pad2",
1772),
1773dict(
1774module_name="Conv1d",
1775constructor_args=(4, 4, 3, 1, 1),
1776cpp_constructor_args="torch::nn::Conv1dOptions(4, 4, 3).stride(1).padding(1)",
1777input_size=(1, 4, 1),
1778cudnn=True,
1779desc="pad1size1",
1780),
1781dict(
1782module_name="Conv1d",
1783constructor_args=(4, 4, 5, 1, 2),
1784cpp_constructor_args="torch::nn::Conv1dOptions(4, 4, 5).stride(1).padding(2)",
1785input_size=(1, 4, 1),
1786cudnn=True,
1787desc="pad2size1",
1788),
1789dict(
1790module_name="Conv1d",
1791constructor_args=(4, 5, 3),
1792cpp_constructor_args="torch::nn::Conv1dOptions(4, 5, 3)",
1793input_size=(0, 4, 10),
1794cudnn=True,
1795desc="zero_batch",
1796test_cuda=(not TEST_WITH_ROCM),
1797),
1798dict(
1799fullname="Conv1d_dilated",
1800constructor=lambda: nn.Conv1d(4, 5, kernel_size=3, dilation=2),
1801cpp_constructor_args="torch::nn::Conv1dOptions(4, 5, 3).dilation(2)",
1802input_size=(2, 4, 10),
1803),
1804dict(
1805fullname="Conv1d_groups",
1806constructor=lambda: nn.Conv1d(4, 6, kernel_size=3, groups=2),
1807cpp_constructor_args="torch::nn::Conv1dOptions(4, 6, 3).groups(2)",
1808input_size=(2, 4, 6),
1809cudnn=True,
1810),
1811dict(
1812fullname="ConvTranspose1d",
1813constructor=lambda: nn.ConvTranspose1d(
18143, 4, kernel_size=3, stride=(3,), padding=1, output_padding=(1,)
1815),
1816cpp_constructor_args="torch::nn::ConvTranspose1dOptions(3, 4, 3).stride(3).padding(1).output_padding(1)",
1817cudnn=True,
1818input_size=(1, 3, 7),
1819),
1820dict(
1821module_name="ConvTranspose1d",
1822constructor_args=(3, 4, 3, 2, 1, 1, 1, False),
1823cpp_constructor_args="""torch::nn::ConvTranspose1dOptions(3, 4, 3)
1824.stride(2).padding(1).output_padding(1).groups(1).bias(false)""",
1825input_size=(1, 3, 6),
1826cudnn=True,
1827desc="no_bias",
1828),
1829dict(
1830module_name="ConvTranspose1d",
1831constructor_args=(3, 4, 3, 2, 1, 1, 1, True, 2),
1832cpp_constructor_args="""torch::nn::ConvTranspose1dOptions(3, 4, 3)
1833.stride(2).padding(1).output_padding(1).groups(1).bias(true).dilation(2)""",
1834input_size=(1, 3, 6),
1835cudnn=True,
1836desc="dilated",
1837),
1838dict(
1839fullname="ConvTranspose1d_groups",
1840constructor=lambda: nn.ConvTranspose1d(
18414, 6, 3, stride=(3,), padding=1, output_padding=(1,), groups=2
1842),
1843cpp_constructor_args="""torch::nn::ConvTranspose1dOptions(4, 6, 3)
1844.stride(3).padding(1).output_padding(1).groups(2)""",
1845cudnn=True,
1846input_size=(2, 4, 7),
1847),
1848dict(
1849module_name="MaxPool1d",
1850constructor_args=(4,),
1851cpp_constructor_args="torch::nn::MaxPool1dOptions(4)",
1852input_size=(2, 10, 4),
1853),
1854dict(
1855module_name="MaxPool1d",
1856constructor_args=(4, 4),
1857cpp_constructor_args="torch::nn::MaxPool1dOptions(4).stride(4)",
1858input_size=(2, 10, 4),
1859desc="stride",
1860),
1861dict(
1862module_name="Conv2d",
1863constructor_args=(3, 4, (3, 2)),
1864cpp_constructor_args="torch::nn::Conv2dOptions(3, 4, {3, 2})",
1865input_size=(2, 3, 7, 5),
1866cudnn=True,
1867check_with_long_tensor=True,
1868),
1869dict(
1870module_name="Conv2d",
1871constructor_args=(3, 4, (3, 3), (2, 2)),
1872cpp_constructor_args="torch::nn::Conv2dOptions(3, 4, {3, 3}).stride({2, 2})",
1873input_size=(2, 3, 6, 6),
1874cudnn=True,
1875desc="strided",
1876check_with_long_tensor=True,
1877),
1878dict(
1879module_name="Conv2d",
1880constructor_args=(3, 4, (3, 3), (2, 2), (1, 1)),
1881cpp_constructor_args="torch::nn::Conv2dOptions(3, 4, {3, 3}).stride({2, 2}).padding({1, 1})",
1882input_size=(2, 3, 6, 6),
1883cudnn=True,
1884desc="padding",
1885check_with_long_tensor=True,
1886),
1887dict(
1888module_name="Conv2d",
1889constructor_args=(3, 2, (3, 3), (2, 2), (1, 1), (2, 2)),
1890cpp_constructor_args="torch::nn::Conv2dOptions(3, 2, {3, 3}).stride({2, 2}).padding({1, 1}).dilation({2, 2})",
1891input_size=(2, 3, 8, 8),
1892cudnn=True,
1893desc="dilated",
1894check_with_long_tensor=True,
1895),
1896dict(
1897module_name="Conv2d",
1898constructor_args=(3, 4, (3, 2), 1, 0, 1, 1, False),
1899cpp_constructor_args="""torch::nn::Conv2dOptions(3, 4, {3, 2})
1900.stride(1).padding(0).dilation(1).groups(1).bias(false)""",
1901input_size=(2, 3, 6, 5),
1902cudnn=True,
1903desc="no_bias",
1904check_with_long_tensor=True,
1905),
1906dict(
1907module_name="Conv2d",
1908constructor_args=(3, 4, (3, 2)),
1909cpp_constructor_args="torch::nn::Conv2dOptions(3, 4, {3, 2})",
1910input_size=(0, 3, 7, 5),
1911cudnn=True,
1912desc="zero_batch",
1913check_with_long_tensor=True,
1914test_cuda=(not TEST_WITH_ROCM),
1915),
1916dict(
1917fullname="Conv2d_groups",
1918constructor=lambda: nn.Conv2d(4, 6, (3, 2), groups=2),
1919cpp_constructor_args="torch::nn::Conv2dOptions(4, 6, {3, 2}).groups(2)",
1920input_size=(2, 4, 6, 5),
1921cudnn=True,
1922check_with_long_tensor=True,
1923),
1924dict(
1925fullname="Conv2d_groups_thnn",
1926constructor=lambda: nn.Conv2d(4, 6, (3, 2), groups=2),
1927cpp_constructor_args="torch::nn::Conv2dOptions(4, 6, {3, 2}).groups(2)",
1928input_size=(2, 4, 6, 5),
1929check_with_long_tensor=True,
1930),
1931dict(
1932module_name="ConvTranspose2d",
1933constructor_args=(3, 4, 3, (3, 2), 1, (1, 1)),
1934cpp_constructor_args="""torch::nn::ConvTranspose2dOptions(3, 4, 3)
1935.stride({3, 2}).padding(1).output_padding({1, 1})""",
1936cudnn=True,
1937input_size=(1, 3, 7, 6),
1938check_with_long_tensor=True,
1939),
1940dict(
1941module_name="ConvTranspose2d",
1942constructor_args=(3, 4, 3, (2, 3), 1, (1, 1), 1, False, (2, 2)),
1943cpp_constructor_args="""torch::nn::ConvTranspose2dOptions(3, 4, 3)
1944.stride({2, 3})
1945.padding(1)
1946.output_padding({1, 1})
1947.groups(1)
1948.bias(false)
1949.dilation({2, 2})""",
1950input_size=(1, 3, 6, 7),
1951cudnn=True,
1952desc="dilated",
1953check_with_long_tensor=True,
1954),
1955dict(
1956module_name="ConvTranspose2d",
1957constructor_args=(3, 4, 3, (2, 3), 1, (1, 1), 1, False),
1958cpp_constructor_args="""torch::nn::ConvTranspose2dOptions(3, 4, 3)
1959.stride({2, 3}).padding(1).output_padding({1, 1}).groups(1).bias(false)""",
1960input_size=(1, 3, 6, 7),
1961cudnn=True,
1962desc="no_bias",
1963check_with_long_tensor=True,
1964),
1965dict(
1966fullname="ConvTranspose2d_groups",
1967constructor=lambda: nn.ConvTranspose2d(2, 4, (2, 3), groups=2),
1968cpp_constructor_args="torch::nn::ConvTranspose2dOptions(2, 4, {2, 3}).groups(2)",
1969input_size=(1, 2, 4, 5),
1970cudnn=True,
1971check_with_long_tensor=True,
1972),
1973dict(
1974fullname="Conv2d_depthwise",
1975constructor=lambda: nn.Conv2d(4, 4, (3, 3), groups=4),
1976cpp_constructor_args="torch::nn::Conv2dOptions(4, 4, {3, 3}).groups(4)",
1977input_size=(2, 4, 6, 6),
1978),
1979dict(
1980fullname="Conv2d_depthwise_with_multiplier",
1981constructor=lambda: nn.Conv2d(4, 8, (3, 3), groups=4),
1982cpp_constructor_args="torch::nn::Conv2dOptions(4, 8, {3, 3}).groups(4)",
1983input_size=(2, 4, 6, 6),
1984),
1985dict(
1986fullname="Conv2d_depthwise_strided",
1987constructor=lambda: nn.Conv2d(4, 4, (3, 3), stride=(2, 2), groups=4),
1988cpp_constructor_args="torch::nn::Conv2dOptions(4, 4, {3, 3}).stride({2, 2}).groups(4)",
1989input_size=(2, 4, 6, 6),
1990),
1991dict(
1992fullname="Conv2d_depthwise_padded",
1993constructor=lambda: nn.Conv2d(4, 4, (3, 3), padding=(1, 1), groups=4),
1994cpp_constructor_args="torch::nn::Conv2dOptions(4, 4, {3, 3}).padding({1, 1}).groups(4)",
1995input_size=(2, 4, 6, 6),
1996),
1997dict(
1998fullname="Conv2d_depthwise_dilated",
1999constructor=lambda: nn.Conv2d(4, 4, (2, 2), dilation=(2, 2), groups=4),
2000cpp_constructor_args="torch::nn::Conv2dOptions(4, 4, {2, 2}).dilation({2, 2}).groups(4)",
2001input_size=(2, 4, 5, 5),
2002),
2003dict(
2004module_name="MaxPool2d",
2005constructor_args=((3, 3), (2, 2), (1, 1)),
2006cpp_constructor_args="torch::nn::MaxPool2dOptions({3, 3}).stride({2, 2}).padding({1, 1})",
2007input_size=(3, 7, 7),
2008desc="3d_input",
2009check_gradgrad=False,
2010),
2011dict(
2012module_name="MaxPool2d",
2013constructor_args=((3, 3), (2, 2), (1, 1)),
2014cpp_constructor_args="torch::nn::MaxPool2dOptions({3, 3}).stride({2, 2}).padding({1, 1})",
2015input_size=(1, 3, 7, 7),
2016check_with_channels_last=True,
2017desc="4d_input",
2018check_gradgrad=False,
2019),
2020dict(
2021module_name="AvgPool1d",
2022constructor_args=(2,),
2023cpp_constructor_args="torch::nn::AvgPool1dOptions(2)",
2024input_size=(2, 3, 6),
2025),
2026dict(
2027module_name="AvgPool1d",
2028constructor_args=((2,), (2,)),
2029cpp_constructor_args="torch::nn::AvgPool1dOptions(2).stride(2)",
2030input_size=(2, 3, 6),
2031desc="stride",
2032),
2033dict(
2034module_name="AvgPool1d",
2035constructor_args=(2, 2, 1),
2036cpp_constructor_args="torch::nn::AvgPool1dOptions(2).stride(2).padding(1)",
2037input_size=(2, 3, 6),
2038desc="stride_pad",
2039),
2040dict(
2041module_name="AvgPool2d",
2042constructor_args=((2, 2),),
2043cpp_constructor_args="torch::nn::AvgPool2dOptions({2, 2})",
2044input_size=(2, 3, 6, 6),
2045),
2046dict(
2047module_name="AvgPool2d",
2048constructor_args=((2, 2), (2, 2)),
2049cpp_constructor_args="torch::nn::AvgPool2dOptions({2, 2}).stride({2, 2})",
2050input_size=(2, 3, 6, 6),
2051desc="stride",
2052),
2053dict(
2054module_name="AvgPool2d",
2055constructor_args=((2, 2), (2, 2), (1, 1)),
2056cpp_constructor_args="torch::nn::AvgPool2dOptions({2, 2}).stride({2, 2}).padding({1, 1})",
2057input_size=(2, 3, 6, 6),
2058desc="stride_pad",
2059),
2060dict(
2061fullname="AvgPool2d_divisor",
2062constructor=lambda: nn.AvgPool2d((2, 2), divisor_override=1),
2063cpp_constructor_args="torch::nn::AvgPool2dOptions({2, 2}).divisor_override(1)",
2064input_size=(2, 3, 6, 6),
2065check_with_long_tensor=True,
2066),
2067dict(
2068fullname="AvgPool2d_divisor_stride",
2069constructor=lambda: nn.AvgPool2d((2, 2), (2, 2), divisor_override=1),
2070cpp_constructor_args="torch::nn::AvgPool2dOptions({2, 2}).stride({2, 2}).divisor_override(1)",
2071input_size=(2, 3, 6, 6),
2072check_with_long_tensor=True,
2073),
2074dict(
2075fullname="AvgPool2d_divisor_stride_pad",
2076constructor=lambda: nn.AvgPool2d((2, 2), (2, 2), (1, 1), divisor_override=1),
2077cpp_constructor_args="torch::nn::AvgPool2dOptions({2, 2}).stride({2, 2}).padding({1, 1}).divisor_override(1)",
2078input_size=(2, 3, 6, 6),
2079check_with_long_tensor=True,
2080),
2081dict(
2082module_name="LPPool2d",
2083constructor_args=(2, 2, 2),
2084cpp_constructor_args="torch::nn::LPPool2dOptions(2, 2).stride(2)",
2085input_size=(1, 3, 7, 7),
2086),
2087dict(
2088module_name="LPPool2d",
2089constructor_args=(1.5, 2),
2090cpp_constructor_args="torch::nn::LPPool2dOptions(1.5, 2)",
2091input_fn=lambda: torch.rand(1, 3, 7, 7),
2092desc="norm",
2093),
2094dict(
2095module_name="LPPool1d",
2096constructor_args=(1.5, 2),
2097cpp_constructor_args="torch::nn::LPPool1dOptions(1.5, 2)",
2098input_fn=lambda: torch.rand(1, 3, 7),
2099desc="norm",
2100),
2101dict(
2102module_name="LPPool1d",
2103constructor_args=(2, 2, 3),
2104cpp_constructor_args="torch::nn::LPPool1dOptions(2, 2).stride(3)",
2105input_size=(1, 3, 7),
2106),
2107dict(
2108module_name="LocalResponseNorm",
2109constructor_args=(3,),
2110cpp_constructor_args="torch::nn::LocalResponseNormOptions(3)",
2111input_size=(1, 5, 7),
2112desc="1d",
2113),
2114dict(
2115module_name="LocalResponseNorm",
2116constructor_args=(2,),
2117cpp_constructor_args="torch::nn::LocalResponseNormOptions(2)",
2118input_size=(1, 5, 7, 7),
2119desc="2d_uneven_pad",
2120),
2121dict(
2122module_name="LocalResponseNorm",
2123constructor_args=(1, 1.0, 0.5, 2.0),
2124cpp_constructor_args="torch::nn::LocalResponseNormOptions(1).alpha(1.).beta(0.5).k(2.)",
2125input_size=(1, 5, 7, 7, 7),
2126desc="3d_custom_params",
2127),
2128dict(
2129module_name="ReflectionPad1d",
2130constructor_args=((1, 2),),
2131cpp_constructor_args="torch::nn::ReflectionPad1dOptions({1, 2})",
2132input_size=(2, 3, 8),
2133),
2134dict(
2135module_name="ReflectionPad2d",
2136constructor_args=((1, 2, 3, 4),),
2137cpp_constructor_args="torch::nn::ReflectionPad2dOptions({1, 2, 3, 4})",
2138input_size=(2, 3, 8, 8),
2139),
2140dict(
2141module_name="ReplicationPad1d",
2142constructor_args=((1, 2),),
2143cpp_constructor_args="torch::nn::ReplicationPad1dOptions({1, 2})",
2144input_size=(2, 3, 4),
2145),
2146dict(
2147module_name="ReplicationPad2d",
2148constructor_args=((1, 2, 3, 4),),
2149cpp_constructor_args="torch::nn::ReplicationPad2dOptions({1, 2, 3, 4})",
2150input_size=(2, 3, 4, 4),
2151),
2152dict(
2153module_name="ZeroPad2d",
2154constructor_args=((1, 2, 3, 4),),
2155cpp_constructor_args="torch::nn::ZeroPad2dOptions({1, 2, 3, 4})",
2156input_size=(2, 3, 4, 4),
2157),
2158dict(
2159module_name="ZeroPad2d",
2160constructor_args=((-1, -1, -1, -2),),
2161cpp_constructor_args="torch::nn::ZeroPad2dOptions({-1, -1, -1, -2})",
2162input_size=(2, 3, 4, 4),
2163desc="negative_dims",
2164),
2165dict(
2166module_name="ConstantPad1d",
2167constructor_args=((1, 2), 2.0),
2168cpp_constructor_args="torch::nn::ConstantPad1dOptions({1, 2}, 2.)",
2169input_size=(2, 3, 4),
2170),
2171dict(
2172module_name="ConstantPad2d",
2173constructor_args=((1, 2, 3, 4), 2.0),
2174cpp_constructor_args="torch::nn::ConstantPad2dOptions({1, 2, 3, 4}, 2.)",
2175input_size=(2, 3, 4, 4),
2176),
2177dict(
2178module_name="ConstantPad3d",
2179constructor_args=((1, 2, 3, 4, 1, 0), 2.0),
2180cpp_constructor_args="torch::nn::ConstantPad3dOptions({1, 2, 3, 4, 1, 0}, 2.)",
2181input_size=(2, 3, 4, 4, 5),
2182),
2183dict(
2184module_name="Conv3d",
2185constructor_args=(3, 4, (2, 3, 4)),
2186cpp_constructor_args="torch::nn::Conv3dOptions(3, 4, {2, 3, 4})",
2187input_size=(2, 3, 3, 4, 5),
2188cudnn=True,
2189check_with_long_tensor=True,
2190),
2191dict(
2192module_name="Conv3d",
2193constructor_args=(3, 4, (2, 3, 4), 1, 0, 1, 1, False),
2194cpp_constructor_args="""torch::nn::Conv3dOptions(3, 4, {2, 3, 4})
2195.stride(1).padding(0).dilation(1).groups(1).bias(false)""",
2196input_size=(2, 3, 3, 4, 5),
2197cudnn=True,
2198desc="no_bias",
2199check_with_long_tensor=True,
2200),
2201dict(
2202module_name="Conv3d",
2203constructor_args=(3, 4, 2, 2),
2204cpp_constructor_args="torch::nn::Conv3dOptions(3, 4, 2).stride(2)",
2205input_size=(2, 3, 5, 5, 5),
2206cudnn=True,
2207desc="stride",
2208check_with_long_tensor=True,
2209),
2210dict(
2211module_name="Conv3d",
2212constructor_args=(3, 4, 2, 2, 1),
2213cpp_constructor_args="torch::nn::Conv3dOptions(3, 4, 2).stride(2).padding(1)",
2214input_size=(2, 3, 5, 5, 5),
2215cudnn=True,
2216desc="stride_padding",
2217check_with_long_tensor=True,
2218),
2219dict(
2220module_name="Conv3d",
2221constructor_args=(3, 4, (2, 3, 4)),
2222cpp_constructor_args="torch::nn::Conv3dOptions(3, 4, {2, 3, 4})",
2223input_size=(0, 3, 3, 4, 5),
2224cudnn=True,
2225check_with_long_tensor=True,
2226desc="zero_batch",
2227test_cuda=(not TEST_WITH_ROCM),
2228),
2229dict(
2230fullname="Conv3d_groups",
2231constructor=lambda: nn.Conv3d(4, 6, kernel_size=3, groups=2),
2232cpp_constructor_args="torch::nn::Conv3dOptions(4, 6, 3).groups(2)",
2233input_size=(2, 4, 4, 5, 4),
2234cudnn=True,
2235check_with_long_tensor=True,
2236),
2237dict(
2238fullname="Conv3d_dilated",
2239constructor=lambda: nn.Conv3d(3, 4, kernel_size=2, dilation=2),
2240cpp_constructor_args="torch::nn::Conv3dOptions(3, 4, 2).dilation(2)",
2241input_size=(2, 3, 5, 5, 5),
2242),
2243dict(
2244fullname="Conv3d_dilated_strided",
2245constructor=lambda: nn.Conv3d(3, 4, kernel_size=2, dilation=2, stride=2),
2246cpp_constructor_args="torch::nn::Conv3dOptions(3, 4, 2).dilation(2).stride(2)",
2247input_size=(2, 3, 5, 5, 5),
2248),
2249dict(
2250module_name="ConvTranspose3d",
2251constructor_args=(2, 3, (2, 3, 2)),
2252cpp_constructor_args="torch::nn::ConvTranspose3dOptions(2, 3, {2, 3, 2})",
2253cudnn=True,
2254input_size=(1, 2, 4, 5, 4),
2255),
2256dict(
2257module_name="ConvTranspose3d",
2258constructor_args=(2, 3, (2, 3, 2), 1, 0, 0, 1, True, (2, 2, 2)),
2259cpp_constructor_args="""torch::nn::ConvTranspose3dOptions(2, 3, {2, 3, 2})
2260.stride(1).padding(0).output_padding(0).groups(1).bias(true).dilation({2, 2, 2})""",
2261cudnn=True,
2262input_size=(1, 2, 4, 5, 4),
2263desc="dilated",
2264),
2265dict(
2266module_name="MaxPool3d",
2267constructor_args=((2, 2, 2),),
2268cpp_constructor_args="torch::nn::MaxPool3dOptions({2, 2, 2})",
2269input_size=(2, 3, 5, 5, 5),
2270check_gradgrad=False,
2271),
2272dict(
2273module_name="MaxPool3d",
2274constructor_args=(2, (2, 2, 2)),
2275cpp_constructor_args="torch::nn::MaxPool3dOptions(2).stride({2, 2, 2})",
2276input_size=(2, 3, 5, 5, 5),
2277desc="stride",
2278check_gradgrad=False,
2279),
2280dict(
2281module_name="MaxPool3d",
2282constructor_args=(2, 2, (1, 1, 1)),
2283cpp_constructor_args="torch::nn::MaxPool3dOptions(2).stride(2).padding({1, 1, 1})",
2284input_size=(2, 3, 5, 5, 5),
2285desc="stride_padding",
2286check_gradgrad=False,
2287),
2288dict(
2289module_name="AvgPool3d",
2290constructor_args=((2, 2, 2),),
2291cpp_constructor_args="torch::nn::AvgPool3dOptions({2, 2, 2})",
2292input_size=(2, 3, 4, 4, 4),
2293),
2294dict(
2295module_name="AvgPool3d",
2296constructor_args=(2, (2, 2, 2)),
2297cpp_constructor_args="torch::nn::AvgPool3dOptions(2).stride({2, 2, 2})",
2298input_size=(2, 3, 5, 5, 5),
2299desc="stride",
2300),
2301dict(
2302module_name="AvgPool3d",
2303constructor_args=(2, 2, (1, 1, 1)),
2304cpp_constructor_args="torch::nn::AvgPool3dOptions(2).stride(2).padding({1, 1, 1})",
2305input_size=(2, 3, 5, 5, 5),
2306desc="stride_pad",
2307),
2308dict(
2309module_name="AvgPool3d",
2310constructor_args=(4, 2, (1, 2, 1)),
2311cpp_constructor_args="torch::nn::AvgPool3dOptions(4).stride(2).padding({1, 2, 1})",
2312input_size=(2, 3, 5, 5, 5),
2313desc="stride_pad_gpu_fixedkw_output",
2314),
2315dict(
2316module_name="AvgPool3d",
2317constructor_args=((2, 4, 8), 1, (1, 1, 2)),
2318cpp_constructor_args="torch::nn::AvgPool3dOptions({2, 4, 8}).stride(1).padding({1, 1, 2})",
2319input_size=(2, 3, 2, 4, 8),
2320desc="stride_pad_gpu_general_output",
2321),
2322dict(
2323module_name="AvgPool3d",
2324constructor_args=(3, 1, 0),
2325cpp_constructor_args="torch::nn::AvgPool3dOptions(3).stride(1).padding(0)",
2326input_size=(2, 3, 4, 4, 4),
2327desc="stride1_pad0_gpu_input",
2328),
2329dict(
2330module_name="AvgPool3d",
2331constructor_args=(2, 2, (1, 1, 1)),
2332cpp_constructor_args="torch::nn::AvgPool3dOptions(2).stride(2).padding({1, 1, 1})",
2333input_size=(2, 3, 4, 4, 4),
2334desc="stride_pad_gpu_input_nooverlap",
2335),
2336dict(
2337fullname="AvgPool3d_divisor",
2338constructor=lambda: nn.AvgPool3d((2, 2, 2), divisor_override=1),
2339cpp_constructor_args="torch::nn::AvgPool3dOptions({2, 2, 2}).divisor_override(1)",
2340input_size=(2, 3, 4, 4, 4),
2341check_with_long_tensor=True,
2342),
2343dict(
2344fullname="AvgPool3d_divisor_stride",
2345constructor=lambda: nn.AvgPool3d(2, (2, 2, 2), divisor_override=1),
2346cpp_constructor_args="torch::nn::AvgPool3dOptions(2).stride({2, 2, 2}).divisor_override(1)",
2347input_size=(2, 3, 5, 5, 5),
2348check_with_long_tensor=True,
2349),
2350dict(
2351fullname="AvgPool3d_divisor_stride_pad",
2352constructor=lambda: nn.AvgPool3d(2, 2, (1, 1, 1), divisor_override=1),
2353cpp_constructor_args="torch::nn::AvgPool3dOptions(2).stride(2).padding({1, 1, 1}).divisor_override(1)",
2354input_size=(2, 3, 5, 5, 5),
2355check_with_long_tensor=True,
2356),
2357dict(
2358fullname="AvgPool3d_divisor_stride_pad_gpu_fixedkw_output",
2359constructor=lambda: nn.AvgPool3d(4, 2, (1, 2, 1), divisor_override=1),
2360cpp_constructor_args="torch::nn::AvgPool3dOptions(4).stride(2).padding({1, 2, 1}).divisor_override(1)",
2361input_size=(2, 3, 5, 5, 5),
2362check_with_long_tensor=True,
2363),
2364dict(
2365fullname="AvgPool3d_divisor_stride_pad_gpu_general_output",
2366constructor=lambda: nn.AvgPool3d((2, 4, 8), 1, (1, 1, 2), divisor_override=1),
2367cpp_constructor_args="torch::nn::AvgPool3dOptions({2, 4, 8}).stride(1).padding({1, 1, 2}).divisor_override(1)",
2368input_size=(2, 3, 2, 4, 8),
2369check_with_long_tensor=True,
2370),
2371dict(
2372fullname="AvgPool3d_divisor_stride1_pad0_gpu_input",
2373constructor=lambda: nn.AvgPool3d(3, 1, 0, divisor_override=1),
2374cpp_constructor_args="torch::nn::AvgPool3dOptions(3).stride(1).padding(0).divisor_override(1)",
2375input_size=(2, 3, 4, 4, 4),
2376check_with_long_tensor=True,
2377),
2378dict(
2379fullname="AvgPool3d_divisor_stride_pad_gpu_input_nooverlap",
2380constructor=lambda: nn.AvgPool3d(2, 2, (1, 1, 1), divisor_override=1),
2381cpp_constructor_args="torch::nn::AvgPool3dOptions(2).stride(2).padding({1, 1, 1}).divisor_override(1)",
2382input_size=(2, 3, 4, 4, 4),
2383check_with_long_tensor=True,
2384),
2385dict(
2386module_name="ReplicationPad3d",
2387constructor_args=((1, 2, 3, 4, 5, 6),),
2388cpp_constructor_args="torch::nn::ReplicationPad3dOptions({1, 2, 3, 4, 5, 6})",
2389input_size=(2, 3, 5, 5, 5),
2390),
2391dict(
2392module_name="Embedding",
2393constructor_args=(4, 3),
2394cpp_constructor_args="torch::nn::EmbeddingOptions(4, 3)",
2395input_fn=lambda: torch.empty(2, 3, dtype=torch.long).random_(4),
2396jacobian_input=False,
2397check_gradgrad=False,
2398),
2399dict(
2400module_name="EmbeddingBag",
2401constructor_args=(4, 3),
2402cpp_constructor_args="torch::nn::EmbeddingBagOptions(4, 3)",
2403input_fn=lambda: torch.empty(2, 3, dtype=torch.long).random_(4),
2404jacobian_input=False,
2405check_gradgrad=False,
2406check_forward_only=True,
2407desc="mean",
2408),
2409dict(
2410module_name="EmbeddingBag",
2411constructor_args=(4, 3, None, 2.0, False, "sum"),
2412cpp_constructor_args="""torch::nn::EmbeddingBagOptions(4, 3)
2413.max_norm(c10::nullopt).norm_type(2.).scale_grad_by_freq(false).mode(torch::kSum)""",
2414input_fn=lambda: torch.empty(2, 3, dtype=torch.long).random_(4),
2415jacobian_input=False,
2416check_gradgrad=False,
2417check_forward_only=True,
2418desc="sum",
2419),
2420dict(
2421module_name="EmbeddingBag",
2422constructor_args=(4, 3, None, 2.0, False, "max"),
2423cpp_constructor_args="""torch::nn::EmbeddingBagOptions(4, 3)
2424.max_norm(c10::nullopt).norm_type(2.).scale_grad_by_freq(false).mode(torch::kMax)""",
2425input_fn=lambda: torch.empty(2, 3, dtype=torch.long).random_(4),
2426jacobian_input=False,
2427check_gradgrad=False,
2428check_forward_only=True,
2429desc="max",
2430),
2431dict(
2432fullname="EmbeddingBag_sparse",
2433constructor=lambda: nn.EmbeddingBag(4, 3, sparse=True),
2434cpp_constructor_args="torch::nn::EmbeddingBagOptions(4, 3).sparse(true)",
2435input_fn=lambda: torch.randperm(2).repeat(1, 2),
2436jacobian_input=False,
2437check_gradgrad=False,
2438),
2439dict(
2440constructor=lambda: nn.Embedding(4, 3, sparse=True),
2441cpp_constructor_args="torch::nn::EmbeddingOptions(4, 3).sparse(true)",
2442input_fn=lambda: torch.randperm(2).repeat(1, 2),
2443jacobian_input=False,
2444fullname="Embedding_sparse",
2445check_gradgrad=False,
2446),
2447dict(
2448module_name="PixelShuffle",
2449constructor_args=(3,),
2450cpp_constructor_args="torch::nn::PixelShuffleOptions(3)",
2451input_size=(1, 9, 4, 4),
2452),
2453dict(
2454constructor=wrap_functional(
2455F.interpolate, size=12, scale_factor=None, mode="nearest"
2456),
2457cpp_options_args="""F::InterpolateFuncOptions()
2458.size(std::vector<int64_t>({12})).scale_factor(c10::nullopt).mode(torch::kNearest)""",
2459input_size=(1, 2, 4),
2460fullname="interpolate_nearest_1d",
2461pickle=False,
2462),
2463dict(
2464constructor=wrap_functional(
2465F.interpolate, size=12, scale_factor=None, mode="nearest"
2466),
2467cpp_options_args="""F::InterpolateFuncOptions()
2468.size(std::vector<int64_t>({12})).scale_factor(c10::nullopt).mode(torch::kNearest)""",
2469input_size=(0, 2, 4),
2470fullname="interpolate_nearest_1d_zero_dim",
2471pickle=False,
2472),
2473dict(
2474constructor=wrap_functional(
2475F.interpolate, size=(12,), scale_factor=None, mode="nearest"
2476),
2477cpp_options_args="""F::InterpolateFuncOptions()
2478.size(std::vector<int64_t>({12})).scale_factor(c10::nullopt).mode(torch::kNearest)""",
2479input_size=(1, 2, 3),
2480fullname="interpolate_nearest_tuple_1d",
2481pickle=False,
2482),
2483dict(
2484constructor=wrap_functional(
2485F.interpolate, size=None, scale_factor=4.0, mode="nearest"
2486),
2487cpp_options_args="""F::InterpolateFuncOptions()
2488.size(c10::nullopt).scale_factor(std::vector<double>({4.})).mode(torch::kNearest)""",
2489input_size=(1, 2, 4),
2490fullname="interpolate_nearest_scale_1d",
2491pickle=False,
2492),
2493dict(
2494constructor=wrap_functional(
2495F.interpolate,
2496size=12,
2497scale_factor=None,
2498mode="linear",
2499align_corners=False,
2500),
2501cpp_options_args="""F::InterpolateFuncOptions()
2502.size(std::vector<int64_t>({12}))
2503.scale_factor(c10::nullopt)
2504.mode(torch::kLinear)
2505.align_corners(false)""",
2506input_size=(1, 2, 4),
2507fullname="interpolate_linear_1d",
2508pickle=False,
2509),
2510dict(
2511constructor=wrap_functional(
2512F.interpolate,
2513size=(4,),
2514scale_factor=None,
2515mode="linear",
2516align_corners=False,
2517),
2518cpp_options_args="""F::InterpolateFuncOptions()
2519.size(std::vector<int64_t>({4}))
2520.scale_factor(c10::nullopt)
2521.mode(torch::kLinear)
2522.align_corners(false)""",
2523input_size=(1, 2, 3),
2524fullname="interpolate_linear_tuple_1d",
2525pickle=False,
2526),
2527dict(
2528constructor=wrap_functional(
2529F.interpolate,
2530size=None,
2531scale_factor=4.0,
2532mode="linear",
2533align_corners=False,
2534),
2535cpp_options_args="""F::InterpolateFuncOptions()
2536.size(c10::nullopt)
2537.scale_factor(std::vector<double>({4.}))
2538.mode(torch::kLinear)
2539.align_corners(false)""",
2540input_size=(1, 2, 4),
2541fullname="interpolate_linear_scale_1d",
2542pickle=False,
2543),
2544dict(
2545constructor=wrap_functional(
2546F.interpolate,
2547size=12,
2548scale_factor=None,
2549mode="linear",
2550align_corners=False,
2551),
2552cpp_options_args="""F::InterpolateFuncOptions()
2553.size(std::vector<int64_t>({12}))
2554.scale_factor(c10::nullopt)
2555.mode(torch::kLinear)
2556.align_corners(false)""",
2557input_size=(0, 2, 4),
2558fullname="interpolate_linear_1d_zero_dim",
2559pickle=False,
2560),
2561dict(
2562constructor=wrap_functional(
2563F.interpolate, size=12, scale_factor=None, mode="linear", align_corners=True
2564),
2565cpp_options_args="""F::InterpolateFuncOptions()
2566.size(std::vector<int64_t>({12}))
2567.scale_factor(c10::nullopt)
2568.mode(torch::kLinear)
2569.align_corners(true)""",
2570input_size=(1, 2, 4),
2571fullname="interpolate_linear_1d_align_corners",
2572pickle=False,
2573),
2574dict(
2575constructor=wrap_functional(
2576F.interpolate,
2577size=None,
2578scale_factor=4.0,
2579mode="linear",
2580align_corners=True,
2581),
2582cpp_options_args="""F::InterpolateFuncOptions()
2583.size(c10::nullopt)
2584.scale_factor(std::vector<double>({4.}))
2585.mode(torch::kLinear)
2586.align_corners(true)""",
2587input_size=(1, 2, 4),
2588fullname="interpolate_linear_scale_1d_align_corners",
2589pickle=False,
2590),
2591dict(
2592constructor=wrap_functional(
2593F.interpolate, size=2, scale_factor=None, mode="nearest"
2594),
2595cpp_options_args="""F::InterpolateFuncOptions()
2596.size(std::vector<int64_t>({2, 2}))
2597.scale_factor(c10::nullopt)
2598.mode(torch::kNearest)""",
2599input_size=(1, 128, 1, 1),
2600fullname="interpolate_nearest_2d_launch_configs",
2601pickle=False,
2602),
2603dict(
2604constructor=wrap_functional(
2605F.interpolate, size=12, scale_factor=None, mode="nearest"
2606),
2607cpp_options_args="""F::InterpolateFuncOptions()
2608.size(std::vector<int64_t>({12, 12}))
2609.scale_factor(c10::nullopt)
2610.mode(torch::kNearest)""",
2611input_size=(1, 2, 4, 4),
2612fullname="interpolate_nearest_2d",
2613pickle=False,
2614),
2615dict(
2616constructor=wrap_functional(
2617F.interpolate, size=(12, 16), scale_factor=None, mode="nearest"
2618),
2619cpp_options_args="""F::InterpolateFuncOptions()
2620.size(std::vector<int64_t>({12, 16}))
2621.scale_factor(c10::nullopt)
2622.mode(torch::kNearest)""",
2623input_size=(1, 2, 3, 4),
2624fullname="interpolate_nearest_tuple_2d",
2625pickle=False,
2626),
2627dict(
2628constructor=wrap_functional(
2629F.interpolate, size=None, scale_factor=4.0, mode="nearest"
2630),
2631cpp_options_args="""F::InterpolateFuncOptions()
2632.size(c10::nullopt)
2633.scale_factor(std::vector<double>({4., 4.}))
2634.mode(torch::kNearest)""",
2635input_size=(1, 2, 4, 4),
2636fullname="interpolate_nearest_scale_2d",
2637pickle=False,
2638),
2639dict(
2640constructor=wrap_functional(
2641F.interpolate, size=12, scale_factor=None, mode="nearest"
2642),
2643cpp_options_args="""F::InterpolateFuncOptions()
2644.size(std::vector<int64_t>({12, 12}))
2645.scale_factor(c10::nullopt)
2646.mode(torch::kNearest)""",
2647input_size=(0, 2, 4, 4),
2648fullname="interpolate_nearest_2d_zero_dim",
2649pickle=False,
2650),
2651dict(
2652constructor=wrap_functional(
2653F.interpolate,
2654size=12,
2655scale_factor=None,
2656mode="bilinear",
2657align_corners=False,
2658),
2659cpp_options_args="""F::InterpolateFuncOptions()
2660.size(std::vector<int64_t>({12, 12}))
2661.scale_factor(c10::nullopt)
2662.mode(torch::kBilinear)
2663.align_corners(false)""",
2664input_size=(1, 2, 4, 4),
2665fullname="interpolate_bilinear_2d",
2666pickle=False,
2667),
2668dict(
2669constructor=wrap_functional(
2670F.interpolate,
2671size=12,
2672scale_factor=None,
2673mode="bilinear",
2674align_corners=False,
2675),
2676cpp_options_args="""F::InterpolateFuncOptions()
2677.size(std::vector<int64_t>({12, 12}))
2678.scale_factor(c10::nullopt)
2679.mode(torch::kBilinear)
2680.align_corners(false)""",
2681input_size=(0, 2, 4, 4),
2682fullname="interpolate_bilinear_2d_zero_dim",
2683pickle=False,
2684),
2685dict(
2686constructor=wrap_functional(
2687F.interpolate,
2688size=(4, 6),
2689scale_factor=None,
2690mode="bilinear",
2691align_corners=False,
2692),
2693cpp_options_args="""F::InterpolateFuncOptions()
2694.size(std::vector<int64_t>({4, 6}))
2695.scale_factor(c10::nullopt)
2696.mode(torch::kBilinear)
2697.align_corners(false)""",
2698input_size=(1, 2, 2, 3),
2699fullname="interpolate_bilinear_tuple_2d",
2700pickle=False,
2701),
2702dict(
2703constructor=wrap_functional(
2704F.interpolate,
2705size=None,
2706scale_factor=4.0,
2707mode="bilinear",
2708align_corners=False,
2709),
2710cpp_options_args="""F::InterpolateFuncOptions()
2711.size(c10::nullopt)
2712.scale_factor(std::vector<double>({4., 4.}))
2713.mode(torch::kBilinear)
2714.align_corners(false)""",
2715input_size=(1, 2, 4, 4),
2716fullname="interpolate_bilinear_scale_2d",
2717pickle=False,
2718),
2719dict(
2720constructor=wrap_functional(
2721F.interpolate,
2722size=None,
2723scale_factor=(2.0, 2.0),
2724mode="bilinear",
2725align_corners=False,
2726),
2727cpp_options_args="""F::InterpolateFuncOptions()
2728.size(c10::nullopt)
2729.scale_factor(std::vector<double>({2., 2.}))
2730.mode(torch::kBilinear)
2731.align_corners(false)""",
2732input_size=(1, 2, 4, 4),
2733fullname="interpolate_bilinear_scale_tuple_shared_2d",
2734pickle=False,
2735),
2736dict(
2737constructor=wrap_functional(
2738F.interpolate,
2739size=None,
2740scale_factor=(2.0, 1.0),
2741mode="bilinear",
2742align_corners=False,
2743),
2744cpp_options_args="""F::InterpolateFuncOptions()
2745.size(c10::nullopt)
2746.scale_factor(std::vector<double>({2., 1.}))
2747.mode(torch::kBilinear)
2748.align_corners(false)""",
2749input_size=(1, 2, 4, 4),
2750fullname="interpolate_bilinear_scale_tuple_skewed_2d",
2751pickle=False,
2752),
2753dict(
2754constructor=wrap_functional(
2755F.interpolate,
2756size=(4, 6),
2757scale_factor=None,
2758mode="bilinear",
2759align_corners=True,
2760),
2761cpp_options_args="""F::InterpolateFuncOptions()
2762.size(std::vector<int64_t>({4, 6}))
2763.scale_factor(c10::nullopt)
2764.mode(torch::kBilinear)
2765.align_corners(true)""",
2766input_size=(1, 2, 4, 4),
2767fullname="interpolate_bilinear_tuple_2d_align_corners",
2768pickle=False,
2769),
2770dict(
2771constructor=wrap_functional(
2772F.interpolate,
2773size=None,
2774scale_factor=(2.0, 1.0),
2775mode="bilinear",
2776align_corners=True,
2777),
2778cpp_options_args="""F::InterpolateFuncOptions()
2779.size(c10::nullopt)
2780.scale_factor(std::vector<double>({2., 1.}))
2781.mode(torch::kBilinear)
2782.align_corners(true)""",
2783input_size=(1, 2, 4, 4),
2784fullname="interpolate_bilinear_scale_tuple_skewed_2d_align_corners",
2785pickle=False,
2786),
2787dict(
2788constructor=wrap_functional(
2789F.interpolate,
2790size=12,
2791scale_factor=None,
2792mode="bicubic",
2793align_corners=False,
2794),
2795cpp_options_args="""F::InterpolateFuncOptions()
2796.size(std::vector<int64_t>({12, 12}))
2797.scale_factor(c10::nullopt)
2798.mode(torch::kBicubic)
2799.align_corners(false)""",
2800input_size=(1, 2, 4, 4),
2801fullname="interpolate_bicubic_2d",
2802pickle=False,
2803),
2804dict(
2805constructor=wrap_functional(
2806F.interpolate,
2807size=12,
2808scale_factor=None,
2809mode="bicubic",
2810align_corners=False,
2811),
2812cpp_options_args="""F::InterpolateFuncOptions()
2813.size(std::vector<int64_t>({12, 12}))
2814.scale_factor(c10::nullopt)
2815.mode(torch::kBicubic)
2816.align_corners(false)""",
2817input_size=(0, 2, 4, 4),
2818fullname="interpolate_bicubic_2d_zero_dim",
2819pickle=False,
2820),
2821dict(
2822constructor=wrap_functional(
2823F.interpolate,
2824size=(4, 6),
2825scale_factor=None,
2826mode="bicubic",
2827align_corners=False,
2828),
2829cpp_options_args="""F::InterpolateFuncOptions()
2830.size(std::vector<int64_t>({4, 6}))
2831.scale_factor(c10::nullopt)
2832.mode(torch::kBicubic)
2833.align_corners(false)""",
2834input_size=(1, 2, 2, 3),
2835fullname="interpolate_bicubic_tuple_2d",
2836pickle=False,
2837),
2838dict(
2839constructor=wrap_functional(
2840F.interpolate,
2841size=None,
2842scale_factor=4.0,
2843mode="bicubic",
2844align_corners=False,
2845),
2846cpp_options_args="""F::InterpolateFuncOptions()
2847.size(c10::nullopt)
2848.scale_factor(std::vector<double>({4., 4.}))
2849.mode(torch::kBicubic)
2850.align_corners(false)""",
2851input_size=(1, 2, 4, 4),
2852fullname="interpolate_bicubic_scale_2d",
2853pickle=False,
2854),
2855dict(
2856constructor=wrap_functional(
2857F.interpolate,
2858size=None,
2859scale_factor=(2.0, 2.0),
2860mode="bicubic",
2861align_corners=False,
2862),
2863cpp_options_args="""F::InterpolateFuncOptions()
2864.size(c10::nullopt)
2865.scale_factor(std::vector<double>({2., 2.}))
2866.mode(torch::kBicubic)
2867.align_corners(false)""",
2868input_size=(1, 2, 4, 4),
2869fullname="interpolate_bicubic_scale_tuple_shared_2d",
2870pickle=False,
2871),
2872dict(
2873constructor=wrap_functional(
2874F.interpolate,
2875size=None,
2876scale_factor=(2.0, 1.0),
2877mode="bicubic",
2878align_corners=False,
2879),
2880cpp_options_args="""F::InterpolateFuncOptions()
2881.size(c10::nullopt)
2882.scale_factor(std::vector<double>({2., 1.}))
2883.mode(torch::kBicubic)
2884.align_corners(false)""",
2885input_size=(1, 2, 4, 4),
2886fullname="interpolate_bicubic_scale_tuple_skewed_2d",
2887pickle=False,
2888),
2889dict(
2890constructor=wrap_functional(
2891F.interpolate,
2892size=(4, 6),
2893scale_factor=None,
2894mode="bicubic",
2895align_corners=True,
2896),
2897cpp_options_args="""F::InterpolateFuncOptions()
2898.size(std::vector<int64_t>({4, 6}))
2899.scale_factor(c10::nullopt)
2900.mode(torch::kBicubic)
2901.align_corners(true)""",
2902input_size=(1, 2, 4, 4),
2903fullname="interpolate_bicubic_tuple_2d_align_corners",
2904pickle=False,
2905),
2906dict(
2907constructor=wrap_functional(
2908F.interpolate,
2909size=None,
2910scale_factor=(2.0, 1.0),
2911mode="bicubic",
2912align_corners=True,
2913),
2914cpp_options_args="""F::InterpolateFuncOptions()
2915.size(c10::nullopt)
2916.scale_factor(std::vector<double>({2., 1.}))
2917.mode(torch::kBicubic)
2918.align_corners(true)""",
2919input_size=(1, 2, 4, 4),
2920fullname="interpolate_bicubic_scale_tuple_skewed_2d_align_corners",
2921pickle=False,
2922),
2923dict(
2924constructor=wrap_functional(
2925F.interpolate, size=12, scale_factor=None, mode="nearest"
2926),
2927cpp_options_args="""F::InterpolateFuncOptions()
2928.size(std::vector<int64_t>({12, 12, 12}))
2929.scale_factor(c10::nullopt)
2930.mode(torch::kNearest)""",
2931input_size=(1, 2, 4, 4, 4),
2932fullname="interpolate_nearest_3d",
2933pickle=False,
2934),
2935dict(
2936constructor=wrap_functional(
2937F.interpolate, size=12, scale_factor=None, mode="nearest"
2938),
2939cpp_options_args="""F::InterpolateFuncOptions()
2940.size(std::vector<int64_t>({12, 12, 12}))
2941.scale_factor(c10::nullopt)
2942.mode(torch::kNearest)""",
2943input_size=(0, 2, 4, 4, 4),
2944fullname="interpolate_nearest_3d_zero_dim",
2945pickle=False,
2946),
2947dict(
2948constructor=wrap_functional(
2949F.interpolate, size=(12, 16, 16), scale_factor=None, mode="nearest"
2950),
2951cpp_options_args="""F::InterpolateFuncOptions()
2952.size(std::vector<int64_t>({12, 16, 16}))
2953.scale_factor(c10::nullopt)
2954.mode(torch::kNearest)""",
2955input_size=(1, 2, 3, 4, 4),
2956fullname="interpolate_nearest_tuple_3d",
2957pickle=False,
2958),
2959dict(
2960constructor=wrap_functional(
2961F.interpolate, size=None, scale_factor=4.0, mode="nearest"
2962),
2963cpp_options_args="""F::InterpolateFuncOptions()
2964.size(c10::nullopt)
2965.scale_factor(std::vector<double>({4., 4., 4.}))
2966.mode(torch::kNearest)""",
2967input_size=(1, 2, 4, 4, 4),
2968fullname="interpolate_nearest_scale_3d",
2969pickle=False,
2970),
2971dict(
2972constructor=wrap_functional(
2973F.interpolate,
2974size=12,
2975scale_factor=None,
2976mode="trilinear",
2977align_corners=False,
2978),
2979cpp_options_args="""F::InterpolateFuncOptions()
2980.size(std::vector<int64_t>({12, 12, 12}))
2981.scale_factor(c10::nullopt)
2982.mode(torch::kTrilinear)
2983.align_corners(false)""",
2984input_size=(1, 2, 4, 4, 4),
2985fullname="interpolate_trilinear_3d",
2986pickle=False,
2987),
2988dict(
2989constructor=wrap_functional(
2990F.interpolate,
2991size=12,
2992scale_factor=None,
2993mode="trilinear",
2994align_corners=False,
2995),
2996cpp_options_args="""F::InterpolateFuncOptions()
2997.size(std::vector<int64_t>({12, 12, 12}))
2998.scale_factor(c10::nullopt)
2999.mode(torch::kTrilinear)
3000.align_corners(false)""",
3001input_size=(0, 2, 4, 4, 4),
3002fullname="interpolate_trilinear_3d_zero_dim",
3003pickle=False,
3004),
3005dict(
3006constructor=wrap_functional(
3007F.interpolate,
3008size=(4, 6, 6),
3009scale_factor=None,
3010mode="trilinear",
3011align_corners=False,
3012),
3013cpp_options_args="""F::InterpolateFuncOptions()
3014.size(std::vector<int64_t>({4, 6, 6}))
3015.scale_factor(c10::nullopt)
3016.mode(torch::kTrilinear)
3017.align_corners(false)""",
3018input_size=(1, 2, 2, 3, 3),
3019fullname="interpolate_trilinear_tuple_3d",
3020pickle=False,
3021),
3022dict(
3023constructor=wrap_functional(
3024F.interpolate,
3025size=None,
3026scale_factor=3.0,
3027mode="trilinear",
3028align_corners=False,
3029),
3030cpp_options_args="""F::InterpolateFuncOptions()
3031.size(c10::nullopt)
3032.scale_factor(std::vector<double>({3., 3., 3.}))
3033.mode(torch::kTrilinear)
3034.align_corners(false)""",
3035input_size=(1, 2, 3, 4, 4),
3036fullname="interpolate_trilinear_scale_3d",
3037# See https://github.com/pytorch/pytorch/issues/5006
3038precision=3e-4,
3039pickle=False,
3040),
3041dict(
3042constructor=wrap_functional(
3043F.interpolate,
3044size=(4, 6, 6),
3045scale_factor=None,
3046mode="trilinear",
3047align_corners=True,
3048),
3049cpp_options_args="""F::InterpolateFuncOptions()
3050.size(std::vector<int64_t>({4, 6, 6}))
3051.scale_factor(c10::nullopt)
3052.mode(torch::kTrilinear)
3053.align_corners(true)""",
3054input_size=(1, 2, 2, 3, 3),
3055fullname="interpolate_trilinear_tuple_3d_align_corners",
3056pickle=False,
3057),
3058dict(
3059constructor=wrap_functional(
3060F.interpolate,
3061size=None,
3062scale_factor=3.0,
3063mode="trilinear",
3064align_corners=True,
3065),
3066cpp_options_args="""F::InterpolateFuncOptions()
3067.size(c10::nullopt)
3068.scale_factor(std::vector<double>({3., 3., 3.}))
3069.mode(torch::kTrilinear)
3070.align_corners(true)""",
3071input_size=(1, 2, 3, 4, 4),
3072fullname="interpolate_trilinear_scale_3d_align_corners",
3073# See https://github.com/pytorch/pytorch/issues/5006
3074precision=3e-4,
3075pickle=False,
3076),
3077dict(
3078module_name="AdaptiveMaxPool1d",
3079constructor_args=(3,),
3080cpp_constructor_args="torch::nn::AdaptiveMaxPool1dOptions(3)",
3081input_fn=lambda: _rand_tensor_non_equal(1, 3, 5),
3082),
3083dict(
3084module_name="AdaptiveMaxPool2d",
3085constructor_args=(3,),
3086cpp_constructor_args="torch::nn::AdaptiveMaxPool2dOptions(3)",
3087input_fn=lambda: _rand_tensor_non_equal(1, 3, 5, 6),
3088desc="single",
3089),
3090dict(
3091module_name="AdaptiveMaxPool2d",
3092constructor_args=((3, 4),),
3093cpp_constructor_args="torch::nn::AdaptiveMaxPool2dOptions({3, 4})",
3094input_fn=lambda: _rand_tensor_non_equal(1, 3, 5, 6),
3095desc="tuple",
3096),
3097dict(
3098module_name="AdaptiveMaxPool2d",
3099constructor_args=((3, None),),
3100cpp_constructor_args="torch::nn::AdaptiveMaxPool2dOptions({3, c10::nullopt})",
3101input_fn=lambda: _rand_tensor_non_equal(1, 3, 5, 6),
3102desc="tuple_none",
3103),
3104dict(
3105module_name="AdaptiveMaxPool3d",
3106constructor_args=(3,),
3107cpp_constructor_args="torch::nn::AdaptiveMaxPool3dOptions(3)",
3108input_fn=lambda: _rand_tensor_non_equal(2, 3, 5, 6, 7),
3109desc="single",
3110),
3111dict(
3112module_name="AdaptiveMaxPool3d",
3113constructor_args=((3, 4, 5),),
3114cpp_constructor_args="torch::nn::AdaptiveMaxPool3dOptions({3, 4, 5})",
3115input_fn=lambda: _rand_tensor_non_equal(2, 3, 5, 6, 7),
3116desc="tuple",
3117),
3118dict(
3119module_name="AdaptiveMaxPool3d",
3120constructor_args=((3, None, 5),),
3121cpp_constructor_args="torch::nn::AdaptiveMaxPool3dOptions({3, c10::nullopt, 5})",
3122input_fn=lambda: _rand_tensor_non_equal(2, 3, 5, 6, 7),
3123desc="tuple_none",
3124),
3125dict(
3126module_name="AdaptiveMaxPool3d",
3127constructor_args=(3,),
3128cpp_constructor_args="torch::nn::AdaptiveMaxPool3dOptions(3)",
3129input_fn=lambda: _rand_tensor_non_equal(2, 3, 12, 9, 3),
3130desc="single_nonatomic",
3131),
3132dict(
3133module_name="AdaptiveMaxPool3d",
3134constructor_args=((3, 4, 5),),
3135cpp_constructor_args="torch::nn::AdaptiveMaxPool3dOptions({3, 4, 5})",
3136input_fn=lambda: _rand_tensor_non_equal(2, 3, 6, 4, 10),
3137desc="tuple_nonatomic",
3138),
3139dict(
3140module_name="AdaptiveAvgPool1d",
3141constructor_args=(3,),
3142cpp_constructor_args="torch::nn::AdaptiveAvgPool1dOptions(3)",
3143input_fn=lambda: torch.rand(1, 3, 5),
3144),
3145dict(
3146module_name="AdaptiveAvgPool1d",
3147constructor_args=(1,),
3148cpp_constructor_args="torch::nn::AdaptiveAvgPool1dOptions(1)",
3149input_fn=lambda: torch.rand(1, 3, 5),
3150desc="one_output",
3151),
3152dict(
3153module_name="AdaptiveAvgPool2d",
3154constructor_args=(3,),
3155cpp_constructor_args="torch::nn::AdaptiveAvgPool2dOptions(3)",
3156input_fn=lambda: torch.rand(1, 3, 5, 6),
3157desc="single",
3158check_gradgrad=False,
3159),
3160dict(
3161module_name="AdaptiveAvgPool2d",
3162constructor_args=(1,),
3163cpp_constructor_args="torch::nn::AdaptiveAvgPool2dOptions(1)",
3164input_fn=lambda: torch.rand(1, 3, 5, 6),
3165desc="single_1x1output",
3166check_gradgrad=False,
3167),
3168dict(
3169module_name="AdaptiveAvgPool2d",
3170constructor_args=((3, 4),),
3171cpp_constructor_args="torch::nn::AdaptiveAvgPool2dOptions({3, 4})",
3172input_fn=lambda: torch.rand(1, 3, 5, 6),
3173desc="tuple",
3174check_gradgrad=False,
3175),
3176dict(
3177module_name="AdaptiveAvgPool2d",
3178constructor_args=((3, None),),
3179cpp_constructor_args="torch::nn::AdaptiveAvgPool2dOptions({3, c10::nullopt})",
3180input_fn=lambda: torch.rand(1, 3, 5, 6),
3181desc="tuple_none",
3182check_gradgrad=False,
3183),
3184dict(
3185module_name="AdaptiveAvgPool3d",
3186constructor_args=(3,),
3187cpp_constructor_args="torch::nn::AdaptiveAvgPool3dOptions(3)",
3188input_fn=lambda: torch.rand(2, 3, 5, 2, 7),
3189desc="single",
3190),
3191dict(
3192module_name="AdaptiveAvgPool3d",
3193constructor_args=((3, 4, 5),),
3194cpp_constructor_args="torch::nn::AdaptiveAvgPool3dOptions({3, 4, 5})",
3195input_fn=lambda: torch.rand(2, 3, 5, 3, 7),
3196desc="tuple",
3197),
3198dict(
3199module_name="AdaptiveAvgPool3d",
3200constructor_args=((None, 4, 5),),
3201cpp_constructor_args="torch::nn::AdaptiveAvgPool3dOptions({c10::nullopt, 4, 5})",
3202input_fn=lambda: torch.rand(2, 3, 5, 3, 7),
3203desc="tuple_none",
3204),
3205dict(module_name="SELU", input_size=(3, 2, 5), check_inplace=True),
3206dict(module_name="SELU", input_size=(), check_inplace=True, desc="scalar"),
3207dict(
3208module_name="CELU",
3209input_size=(3, 2, 5),
3210constructor_args=(2.0,),
3211cpp_constructor_args="torch::nn::CELUOptions().alpha(2.)",
3212check_inplace=True,
3213reference_fn=lambda x, *_: torch.where(x >= 0, x, 2.0 * ((0.5 * x).exp() - 1)),
3214),
3215dict(
3216module_name="CELU",
3217input_size=(),
3218constructor_args=(2.0,),
3219cpp_constructor_args="torch::nn::CELUOptions().alpha(2.)",
3220check_inplace=True,
3221reference_fn=lambda x, *_: torch.where(x >= 0, x, 2.0 * ((0.5 * x).exp() - 1)),
3222desc="scalar",
3223),
3224dict(
3225module_name="GLU",
3226input_size=(5, 6),
3227),
3228dict(
3229module_name="GLU",
3230constructor_args=(1,),
3231cpp_constructor_args="torch::nn::GLUOptions(1)",
3232input_size=(5, 6, 7),
3233desc="dim",
3234),
3235dict(
3236module_name="GELU",
3237input_size=(),
3238desc="scalar",
3239reference_fn=lambda x, *_: x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0))),
3240),
3241dict(
3242module_name="GELU",
3243input_size=(3, 2, 5),
3244reference_fn=lambda x, *_: x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0))),
3245),
3246dict(
3247constructor=wrap_functional(F.softmax, dim=-1),
3248cpp_options_args="F::SoftmaxFuncOptions(-1)",
3249input_size=(2, 128), # trigger the last-dim algo in CUDA
3250fullname="softmax_lastdim",
3251pickle=False,
3252),
3253dict(
3254constructor=wrap_functional(F.softmax, dim=1, dtype=torch.float64),
3255cpp_options_args="F::SoftmaxFuncOptions(1).dtype(torch::kFloat64)",
3256input_size=(2, 128),
3257fullname="softmax_lastdim_dtype",
3258pickle=False,
3259test_cuda=False,
3260),
3261dict(
3262constructor=wrap_functional(F.softmax, dim=1),
3263cpp_options_args="F::SoftmaxFuncOptions(1)",
3264input_size=(2, 128, 2, 2), # trigger special case of spatial CUDA algo
3265fullname="softmax_spatial_special",
3266pickle=False,
3267test_cuda=(not TEST_WITH_ROCM),
3268),
3269dict(
3270constructor=wrap_functional(F.softmax, dim=1),
3271cpp_options_args="F::SoftmaxFuncOptions(1)",
3272input_size=(2, 2, 4, 4), # regular spatial algorithm
3273fullname="softmax_spatial",
3274pickle=False,
3275),
3276dict(
3277constructor=wrap_functional(F.softmax, dim=1, dtype=torch.float64),
3278cpp_options_args="F::SoftmaxFuncOptions(1).dtype(torch::kFloat64)",
3279input_size=(2, 2, 4, 4), # regular spatial algorithm
3280fullname="softmax_spatial_dtype",
3281pickle=False,
3282test_cuda=False,
3283),
3284dict(
3285constructor=wrap_functional(F.softmax, dim=0),
3286cpp_options_args="F::SoftmaxFuncOptions(0)",
3287input_size=(2, 3, 4, 5),
3288fullname="softmax_functional_dim0",
3289test_cuda=False,
3290pickle=False,
3291),
3292dict(
3293constructor=wrap_functional(F.softmax, dim=3),
3294cpp_options_args="F::SoftmaxFuncOptions(3)",
3295input_size=(2, 3, 4, 5),
3296fullname="softmax_functional_dim3",
3297test_cuda=False,
3298pickle=False,
3299),
3300dict(
3301constructor=wrap_functional(F.softmax, dim=-1),
3302cpp_options_args="F::SoftmaxFuncOptions(-1)",
3303input_size=(),
3304fullname="softmax_functional_scalar",
3305test_cuda=False,
3306pickle=False,
3307),
3308dict(
3309constructor=wrap_functional(F.log_softmax, dim=-1),
3310cpp_options_args="F::LogSoftmaxFuncOptions(-1)",
3311input_size=(2, 128), # trigger the last-dim algo in CUDA
3312fullname="log_softmax_lastdim",
3313pickle=False,
3314),
3315dict(
3316constructor=wrap_functional(F.log_softmax, dim=1),
3317cpp_options_args="F::LogSoftmaxFuncOptions(1)",
3318input_size=(2, 128, 2, 2), # trigger special case of spatial CUDA algo
3319fullname="log_softmax_spatial_special",
3320pickle=False,
3321test_cuda=(not TEST_WITH_ROCM),
3322),
3323dict(
3324constructor=wrap_functional(F.log_softmax, dim=1),
3325cpp_options_args="F::LogSoftmaxFuncOptions(1)",
3326input_size=(2, 2, 4, 4), # regular spatial algorithm
3327fullname="log_softmax_spatial",
3328pickle=False,
3329),
3330dict(
3331constructor=wrap_functional(F.log_softmax, dim=0),
3332cpp_options_args="F::LogSoftmaxFuncOptions(0)",
3333input_size=(2, 3, 4, 5),
3334fullname="log_softmax_dim0",
3335pickle=False,
3336),
3337dict(
3338constructor=wrap_functional(F.log_softmax, dim=3),
3339cpp_options_args="F::LogSoftmaxFuncOptions(3)",
3340input_size=(2, 3, 4, 5),
3341fullname="log_softmax_dim3",
3342pickle=False,
3343),
3344dict(
3345constructor=wrap_functional(F.log_softmax, dim=0),
3346cpp_options_args="F::LogSoftmaxFuncOptions(0)",
3347input_size=(),
3348fullname="log_softmax_scalar",
3349pickle=False,
3350),
3351dict(
3352fullname="Unfold",
3353constructor=lambda: nn.Unfold((2, 2), (1, 1), (0, 0), (1, 1)),
3354cpp_constructor_args="torch::nn::UnfoldOptions({2, 2}).dilation({1, 1}).padding({0, 0}).stride({1, 1})",
3355input_size=(2, 4, 3, 3),
3356check_gradgrad=False,
3357test_cuda=True,
3358),
3359dict(
3360fullname="Fold",
3361constructor=lambda: nn.Fold((3, 3), (2, 2), (1, 1), (0, 0), (1, 1)),
3362cpp_constructor_args="torch::nn::FoldOptions({3, 3}, {2, 2}).dilation({1, 1}).padding({0, 0}).stride({1, 1})",
3363input_size=(2, 16, 4),
3364check_gradgrad=False,
3365test_cuda=True,
3366),
3367dict(
3368fullname="Unfold_int_input",
3369constructor=lambda: nn.Unfold(2, 1, 0, 1),
3370cpp_constructor_args="torch::nn::UnfoldOptions(2).dilation(1).padding(0).stride(1)",
3371input_size=(2, 4, 3, 3),
3372check_gradgrad=False,
3373test_cuda=True,
3374),
3375dict(
3376fullname="Fold_int_input",
3377constructor=lambda: nn.Fold(3, 2, 1, 0, 1),
3378cpp_constructor_args="torch::nn::FoldOptions(3, 2).dilation(1).padding(0).stride(1)",
3379input_size=(2, 16, 4),
3380check_gradgrad=False,
3381test_cuda=True,
3382),
3383dict(
3384module_name="Threshold",
3385constructor_args=(2.0, 1.0),
3386cpp_constructor_args="torch::nn::ThresholdOptions(2., 1.)",
3387input_size=(),
3388check_inplace=True,
3389desc="threshold_value_scalar",
3390),
3391dict(module_name="ReLU", input_size=(), check_inplace=True, desc="scalar"),
3392dict(module_name="ReLU6", input_size=(), check_inplace=True, desc="scalar"),
3393dict(
3394module_name="RReLU",
3395constructor_args=(0.1, 0.9),
3396cpp_constructor_args="torch::nn::RReLUOptions().lower(0.1).upper(0.9)",
3397input_size=(),
3398desc="with_up_down_scalar",
3399test_cuda=False,
3400),
3401dict(
3402module_name="Hardtanh",
3403input_size=(),
3404reference_fn=lambda i, *_: i.clamp(-1, 1),
3405desc="scalar",
3406),
3407dict(
3408module_name="Sigmoid",
3409input_size=(),
3410desc="scalar",
3411),
3412dict(
3413module_name="Tanh",
3414input_size=(),
3415desc="scalar",
3416),
3417dict(
3418module_name="Softmax",
3419constructor_args=(0,),
3420cpp_constructor_args="torch::nn::SoftmaxOptions(0)",
3421input_size=(),
3422reference_fn=lambda i, *_: torch.exp(i).div(torch.exp(i).sum(0, True)),
3423desc="scalar",
3424),
3425dict(
3426module_name="LogSoftmax",
3427constructor_args=(0,),
3428cpp_constructor_args="torch::nn::LogSoftmaxOptions(0)",
3429input_size=(),
3430reference_fn=lambda i, *_: torch.exp(i).div_(torch.exp(i).sum(0, False)).log_(),
3431desc="multiparam_scalar",
3432),
3433dict(
3434module_name="ELU",
3435constructor_args=(2.0,),
3436cpp_constructor_args="torch::nn::ELUOptions().alpha(2.)",
3437input_size=(),
3438desc="scalar",
3439),
3440dict(
3441module_name="Hardshrink",
3442constructor_args=(2.0,),
3443cpp_constructor_args="torch::nn::HardshrinkOptions(2.)",
3444input_size=(),
3445desc="scalar",
3446),
3447dict(
3448module_name="LeakyReLU",
3449constructor_args=(0.5,),
3450cpp_constructor_args="torch::nn::LeakyReLUOptions().negative_slope(0.5)",
3451input_size=(),
3452check_inplace=True,
3453desc="with_negval_scalar",
3454),
3455dict(
3456module_name="LogSigmoid",
3457input_size=(),
3458reference_fn=lambda i, *_: i.sigmoid().log(),
3459desc="scalar",
3460),
3461dict(
3462module_name="Softplus",
3463constructor_args=(2, -100),
3464cpp_constructor_args="torch::nn::SoftplusOptions().beta(2).threshold(-100)",
3465input_size=(),
3466reference_fn=(
3467lambda i, *_: ((i * 2) > -100).type_as(i) * i
3468+ ((i * 2) <= -100).type_as(i) * 1.0 / 2.0 * torch.log(1 + torch.exp(2 * i))
3469),
3470desc="beta_threshold_scalar",
3471),
3472dict(
3473module_name="Softshrink",
3474constructor_args=(1,),
3475cpp_constructor_args="torch::nn::SoftshrinkOptions(1)",
3476input_size=(),
3477desc="lambda_scalar",
3478),
3479dict(
3480module_name="PReLU",
3481input_size=(),
3482reference_fn=lambda i, p, _: torch.clamp(i, min=0)
3483+ torch.clamp(i, max=0) * p[0][0],
3484desc="scalar",
3485),
3486dict(
3487module_name="Softsign",
3488input_size=(),
3489reference_fn=lambda i, *_: i.div(1 + torch.abs(i)),
3490desc="scalar",
3491),
3492dict(
3493module_name="Softmin",
3494constructor_args=(0,),
3495cpp_constructor_args="torch::nn::SoftminOptions(0)",
3496input_size=(),
3497desc="scalar",
3498),
3499dict(
3500module_name="Tanhshrink",
3501input_size=(),
3502desc="scalar",
3503),
3504dict(
3505fullname="Padding12_1dcircular",
3506constructor=wrap_functional(F.pad, pad=(1, 2), mode="circular"),
3507cpp_options_args="F::PadFuncOptions({1, 2}).mode(torch::kCircular)",
3508input_fn=lambda: torch.arange(6, out=torch.DoubleTensor()).reshape([1, 2, 3]),
3509reference_fn=lambda i, *_: padding1d_circular(i, (1, 2)),
3510skip_double=TEST_WITH_ROCM,
3511pickle=False,
3512),
3513dict(
3514fullname="Padding31_1dcircular",
3515constructor=wrap_functional(F.pad, pad=(3, 1), mode="circular"),
3516cpp_options_args="F::PadFuncOptions({3, 1}).mode(torch::kCircular)",
3517input_fn=lambda: torch.arange(6, out=torch.DoubleTensor()).reshape([1, 2, 3]),
3518reference_fn=lambda i, *_: padding1d_circular(i, (3, 1)),
3519skip_double=TEST_WITH_ROCM,
3520pickle=False,
3521),
3522dict(
3523fullname="Padding33_1dcircular",
3524constructor=wrap_functional(F.pad, pad=(3, 3), mode="circular"),
3525cpp_options_args="F::PadFuncOptions({3, 3}).mode(torch::kCircular)",
3526input_fn=lambda: torch.arange(6, out=torch.DoubleTensor()).reshape([1, 2, 3]),
3527reference_fn=lambda i, *_: padding1d_circular(i, (3, 3)),
3528skip_double=TEST_WITH_ROCM,
3529pickle=False,
3530),
3531dict(
3532fullname="Padding1221_2dcircular",
3533constructor=wrap_functional(F.pad, pad=(1, 2, 2, 1), mode="circular"),
3534cpp_options_args="F::PadFuncOptions({1, 2, 2, 1}).mode(torch::kCircular)",
3535input_fn=lambda: torch.arange(6, out=torch.DoubleTensor()).reshape(
3536[1, 1, 2, 3]
3537),
3538reference_fn=lambda i, *_: padding2d_circular(i, (1, 2, 2, 1)),
3539skip_double=TEST_WITH_ROCM,
3540pickle=False,
3541),
3542dict(
3543fullname="Padding2322_2dcircular",
3544constructor=wrap_functional(F.pad, pad=(2, 3, 2, 2), mode="circular"),
3545cpp_options_args="F::PadFuncOptions({2, 3, 2, 2}).mode(torch::kCircular)",
3546input_fn=lambda: torch.arange(6, out=torch.DoubleTensor()).reshape(
3547[1, 1, 2, 3]
3548),
3549reference_fn=lambda i, *_: padding2d_circular(i, (2, 3, 2, 2)),
3550skip_double=TEST_WITH_ROCM,
3551pickle=False,
3552),
3553dict(
3554fullname="Padding3331_2dcircular",
3555constructor=wrap_functional(F.pad, pad=(3, 3, 3, 1), mode="circular"),
3556cpp_options_args="F::PadFuncOptions({3, 3, 3, 1}).mode(torch::kCircular)",
3557input_fn=lambda: torch.arange(9, out=torch.DoubleTensor()).reshape(
3558[1, 1, 3, 3]
3559),
3560reference_fn=lambda i, *_: padding2d_circular(i, (3, 3, 3, 1)),
3561skip_double=TEST_WITH_ROCM,
3562pickle=False,
3563),
3564dict(
3565fullname="Padding122112_3dcircular",
3566constructor=wrap_functional(F.pad, pad=(1, 2, 2, 1, 1, 2), mode="circular"),
3567cpp_options_args="F::PadFuncOptions({1, 2, 2, 1, 1, 2}).mode(torch::kCircular)",
3568input_fn=lambda: torch.arange(12, out=torch.DoubleTensor()).reshape(
3569[1, 1, 2, 2, 3]
3570),
3571reference_fn=lambda i, *_: padding3d_circular(i, (1, 2, 2, 1, 1, 2)),
3572skip_double=TEST_WITH_ROCM,
3573pickle=False,
3574),
3575dict(
3576fullname="Padding322112_3dcircular",
3577constructor=wrap_functional(F.pad, pad=(3, 2, 2, 1, 1, 2), mode="circular"),
3578cpp_options_args="F::PadFuncOptions({3, 2, 2, 1, 1, 2}).mode(torch::kCircular)",
3579input_fn=lambda: torch.arange(12, out=torch.DoubleTensor()).reshape(
3580[1, 1, 2, 2, 3]
3581),
3582reference_fn=lambda i, *_: padding3d_circular(i, (3, 2, 2, 1, 1, 2)),
3583skip_double=TEST_WITH_ROCM,
3584pickle=False,
3585),
3586dict(
3587fullname="Padding332122_3dcircular",
3588constructor=wrap_functional(F.pad, pad=(3, 3, 2, 1, 2, 2), mode="circular"),
3589cpp_options_args="F::PadFuncOptions({3, 3, 2, 1, 2, 2}).mode(torch::kCircular)",
3590input_fn=lambda: torch.arange(12, out=torch.DoubleTensor()).reshape(
3591[1, 1, 2, 2, 3]
3592),
3593reference_fn=lambda i, *_: padding3d_circular(i, (3, 3, 2, 1, 2, 2)),
3594skip_double=TEST_WITH_ROCM,
3595pickle=False,
3596),
3597]
3598
# add conv padding mode tests:
for padding_mode, cpp_padding_mode in zip(
    ["reflect", "circular", "replicate", "zeros"],
    ["torch::kReflect", "torch::kCircular", "torch::kReplicate", "torch::kZeros"],
):
    # conv signature:
    # in_channels, out_channels, kernel_size, stride=1,
    # padding=0, dilation=1, groups=1,
    # bias=True, padding_mode='zeros'
    for d in (1, 2, 3):
        # FIXME: remove after implementing reflection pad 3d
        # https://github.com/pytorch/pytorch/issues/27655
        if padding_mode == "reflect" and d == 3:
            continue
        entry = dict(
            module_name="Conv{}d".format(d),
            constructor_args=(3, 4, 3, 2, 2, 1, 1, True, padding_mode),
            cpp_constructor_args="""torch::nn::Conv{}dOptions(3, 4, 3)
.stride(2)
.padding(2)
.dilation(1)
.groups(1)
.bias(true)
.padding_mode({})""".format(
                d, cpp_padding_mode
            ),
            input_size=(2, 3) + (3,) * d,
            output_size=(2, 4) + (3,) * d,
            cudnn=True,
            desc="{}_stride2_pad2".format(padding_mode),
        )
        new_module_tests.append(entry)
3632
3633
def kldivloss_reference(input, target, reduction="mean"):
    """Reference KL-divergence loss: target * (log(target) - input).

    ``input`` is expected to hold log-probabilities; ``target`` holds
    probabilities. Entries where target <= 0 contribute zero loss
    (0 * log(0) is defined as 0, matching torch.nn.KLDivLoss).

    Fix: the ``batchmean`` branch referenced an undefined name
    ``results`` (NameError); it now uses ``result``.
    """
    # Zero out non-positive targets so their terms vanish.
    safe_target = target * (target > 0).type_as(target)
    # Add 1 where target <= 0 so the log is finite (the term is zeroed anyway).
    safe_target_log = (safe_target + (target <= 0).type_as(target)).log()
    result = safe_target * (safe_target_log - input)
    if reduction == "mean":
        return result.mean()
    elif reduction == "sum":
        return result.sum()
    elif reduction == "batchmean" and result.dim() != 0:
        # Sum over all elements, averaged over the batch dimension only.
        return result.sum() / result.size(0)
    return result
3645
3646
def nlllossNd_reference(
    input, target, weight=None, ignore_index=-100, reduction="mean"
):
    """Reference N-dimensional NLL loss, computed element by element.

    ``input`` is (N, C, d1, d2, ...); ``target`` is (N, d1, d2, ...) of
    class indices. Targets equal to ``ignore_index`` contribute zero loss
    and zero weight.
    """
    assert input.dim() >= 3
    batch = input.size(0)
    num_classes = input.size(1)
    out_size = (batch,) + input.size()[2:]
    output = torch.zeros(out_size).type_as(input)

    if weight is None:
        weight = torch.ones(num_classes).type_as(input)

    total_weight = 0
    # Walk every spatial position of every sample.
    for idx in product(*(range(s) for s in out_size)):
        cls = target[idx]
        w = 0.0 if ignore_index == cls else weight[cls].item()
        # Index into input at the target class for this position.
        src = list(idx)
        src.insert(1, cls)
        output[idx] = -input[tuple(src)] * w
        total_weight += w

    if reduction == "sum":
        return output.sum()
    if reduction == "mean":
        return output.sum() / total_weight
    return output
3672
3673
def nllloss_reference(input, target, weight=None, ignore_index=-100, reduction="mean"):
    """Reference 1-d NLL loss computed with an explicit per-sample loop."""

    def per_sample(row, cls):
        # Ignored targets contribute zero loss and zero normalization weight.
        if cls == ignore_index:
            return (0, 0)
        w = 1 if weight is None else weight[cls]
        return (-row[cls] * w, w)

    pairs = [per_sample(row, cls) for row, cls in zip(input, target)]
    losses, norms = zip(*pairs)
    losses_tensor = input.new_tensor(losses)
    if reduction == "sum":
        return sum(losses_tensor)
    if reduction == "mean":
        return sum(losses_tensor) / sum(norms)
    return losses_tensor
3693
3694
def smoothl1loss_reference(input, target, reduction="mean"):
    """Reference smooth-L1 loss: quadratic for |diff| < 1, linear beyond."""
    delta = (input - target).abs()
    big = (delta >= 1).type_as(delta)
    small = (delta < 1).type_as(delta)
    out = big * (delta - 0.5) + small * 0.5 * (delta**2)
    if reduction == "sum":
        return out.sum()
    if reduction == "mean":
        return out.mean()
    return out
3705
3706
3707def _multilabelmarginloss_reference(input, target):
3708targets = []
3709for target_index in target:
3710if target_index < 0:
3711break
3712targets.append(target_index)
3713
3714sum = 0
3715for target_index in targets:
3716for i in range(0, len(input)):
3717if i not in targets:
3718sum += max(0, 1 - input[target_index] + input[i])
3719
3720return sum
3721
3722
def multilabelmarginloss_reference(input, target, reduction="mean"):
    """Reference multilabel margin loss; accepts 0-d, 1-d, or 2-d input."""
    orig_dim = input.dim()
    # Promote everything to 2-d (batch, classes) before looping.
    if orig_dim < 2:
        assert target.dim() < 2
        if input.dim() == 1:
            input = input.unsqueeze(0)
        else:
            input = input.unsqueeze(0).unsqueeze(0)
        if target.dim() == 1:
            target = target.unsqueeze(0)
        else:
            target = target.unsqueeze(0).unsqueeze(0)

    batch = input.size(0)
    num_classes = input.size(1)
    per_sample = input.new(batch).zero_()
    for b in range(batch):
        per_sample[b] = _multilabelmarginloss_reference(input[b], target[b])

    if reduction == "mean":
        return per_sample.mean() / num_classes
    if reduction == "sum":
        return per_sample.sum() / num_classes
    if orig_dim < 2:
        # We had (1, C) x (1, C) -> (1,); squeeze restores the original
        # dimensionality.
        return per_sample.squeeze() / num_classes
    return per_sample / num_classes
3753
3754
def hingeembeddingloss_reference(input, target, margin=1.0, reduction="mean"):
    """Reference hinge embedding loss: x where y == 1, max(0, margin - x) otherwise."""
    clamped = (margin - input).clamp(min=0).type_as(input)
    out = torch.where(target == 1, input, clamped)

    if reduction == "sum":
        return out.sum()
    if reduction == "mean":
        return out.mean()
    return out
3764
3765
def softmarginloss_reference(input, target, reduction="mean"):
    """Reference soft margin loss: log(1 + exp(-target * input))."""
    out = (1 + (-input * target).exp()).log()

    if reduction == "sum":
        return out.sum()
    if reduction == "mean":
        return out.mean()
    return out
3774
3775
3776def _multimarginloss_reference(input, target_idx, p, margin, weight):
3777if weight is None:
3778weight = input.new(len(input)).fill_(1)
3779
3780output = 0
3781for i in range(0, len(input)):
3782if i != target_idx:
3783output += max(
37840, weight[target_idx] * (margin - input[target_idx] + input[i]) ** p
3785)
3786return output
3787
3788
def multimarginloss_reference(
    input, target, p=1, margin=1, weight=None, reduction="mean"
):
    """Reference multi-class margin loss over a batch."""
    # Promote input to 2-d (batch, classes).
    if input.dim() < 2:
        if input.dim() == 1:
            input = input.unsqueeze(0)
        else:
            input = input.unsqueeze(0).unsqueeze(0)

    orig_target_dim = target.dim()
    if orig_target_dim == 0:
        target = target.unsqueeze(0)

    batch = input.size(0)
    num_classes = input.size(1)
    per_sample = input.new(batch)
    for b in range(batch):
        per_sample[b] = _multimarginloss_reference(
            input[b], target[b], p, margin, weight
        )

    if reduction == "mean":
        return per_sample.mean() / num_classes
    if reduction == "sum":
        return per_sample.sum() / num_classes
    if orig_target_dim == 0:
        # Scalar target in -> scalar loss out.
        return per_sample.squeeze(0) / num_classes
    return per_sample / num_classes
3814
3815
def cosineembeddingloss_reference(input1, input2, target, margin=0, reduction="mean"):
    """Reference cosine embedding loss.

    1 - cos(x1, x2) where target == 1; max(0, cos(x1, x2) - margin) otherwise.
    """

    def _cos(a, b):
        cos = a.new(a.size(0))
        for row in range(0, a.size(0)):
            num = (a[row] * b[row]).sum()
            # Small epsilon keeps the denominator away from zero.
            denom = (
                ((a[row] * a[row]).sum() + 1e-12)
                * ((b[row] * b[row]).sum() + 1e-12)
            ) ** 0.5
            cos[row] = num / denom
        return cos

    similarity = _cos(input1, input2)
    out = torch.where(
        target == 1,
        1 - similarity,
        (similarity - margin).clamp(min=0),
    )

    if reduction == "sum":
        return out.sum()
    if reduction == "mean":
        return out.mean()
    return out
3836
3837
def tripletmarginloss_reference(
    anchor, positive, negative, margin=1.0, p=2, eps=1e-6, swap=False, reduction="mean"
):
    """Reference triplet margin loss: max(0, margin + d(a, p) - d(a, n))."""
    dist_pos = torch.pairwise_distance(anchor, positive, p, eps)
    dist_neg = torch.pairwise_distance(anchor, negative, p, eps)
    if swap:
        # Use the tighter of the anchor-negative / positive-negative distances.
        dist_swap = torch.pairwise_distance(positive, negative, p, eps)
        dist_neg = torch.min(dist_neg, dist_swap)

    out = torch.clamp(margin + dist_pos - dist_neg, min=0.0)
    if reduction == "sum":
        return out.sum()
    if reduction == "mean":
        return out.mean()
    return out
3853
3854
def marginrankingloss_reference(input1, input2, target, margin=0, reduction="mean"):
    """Reference margin ranking loss: max(0, -target * (x1 - x2) + margin)."""
    out = (margin - target * (input1 - input2)).clamp(min=0)
    if reduction == "sum":
        return out.sum()
    if reduction == "mean":
        return out.mean()
    return out
3862
3863
# this directly follows Graves et al's paper, in contrast to the production implementation, it does not use log-space
def ctcloss_reference(
    log_probs, targets, input_lengths, target_lengths, blank=0, reduction="mean"
):
    """Reference CTC loss via the forward (alpha) recurrence in probability space.

    log_probs: (T, N, C) per-frame log-probabilities.
    targets: label indices, either 2-d (N, S) or 1-d concatenated across the batch.
    input_lengths / target_lengths: per-sample valid lengths.
    Returns per-sample losses, or their mean/sum per ``reduction``.
    """
    input_lengths = torch.as_tensor(input_lengths, dtype=torch.long)
    target_lengths = torch.as_tensor(target_lengths, dtype=torch.long)
    dt = log_probs.dtype
    log_probs = log_probs.double()  # we need the accuracy as we are not in logspace
    targets = targets.long()
    # Offsets into 1-d concatenated targets, one cumulative end per sample.
    cum_target_lengths = target_lengths.cumsum(0)
    losses = []
    for i in range(log_probs.size(1)):
        input_length = input_lengths[i].item()
        target_length = target_lengths[i].item()
        cum_target_length = cum_target_lengths[i].item()
        # Augmented label sequence: blank interleaved before/after every label.
        targets_prime = targets.new_full((2 * target_length + 1,), blank)
        if targets.dim() == 2:
            targets_prime[1::2] = targets[i, :target_length]
        else:
            targets_prime[1::2] = targets[
                cum_target_length - target_length : cum_target_length
            ]
        probs = log_probs[:input_length, i].exp()
        # alpha[s]: probability of all paths ending in augmented state s.
        alpha = log_probs.new_zeros((target_length * 2 + 1,))
        alpha[0] = probs[0, blank]
        alpha[1] = probs[0, targets_prime[1]]
        # Transition from s-2 is allowed only when the labels differ
        # (skipping the blank between two distinct labels).
        mask_third = targets_prime[:-2] != targets_prime[2:]
        for t in range(1, input_length):
            alpha_next = alpha.clone()
            alpha_next[1:] += alpha[:-1]
            alpha_next[2:] += torch.where(mask_third, alpha[:-2], alpha.new_zeros(1))
            alpha = probs[t, targets_prime] * alpha_next
        # Total path probability ends in the last label or the final blank.
        losses.append(-alpha[-2:].sum().log()[None])
    output = torch.cat(losses, 0)
    if reduction == "mean":
        # Per-sample loss normalized by its own target length, then averaged.
        return (
            output / target_lengths.to(dtype=output.dtype, device=output.device)
        ).mean()
    elif reduction == "sum":
        return output.sum()
    output = output.to(dt)
    return output
3906
3907
def padding1d_circular(input, pad):
    r"""Circularly pad the last dimension of a 3-d tensor.

    ``pad`` is (left, right), following the F.pad convention.

    Example: input [[[0., 1., 2.],
                     [3., 4., 5.]]] with pad (1, 2) gives
             [[[2., 0., 1., 2., 0., 1.],
               [5., 3., 4., 5., 3., 4.]]]
    """
    left, right = pad
    return torch.cat([input[:, :, -left:], input, input[:, :, 0:right]], dim=2)
3918
3919
def padding2d_circular(input, pad):
    r"""Circularly pad the last two dimensions of a 4-d tensor.

    ``pad`` is (left, right, top, bottom), following the F.pad convention
    (innermost dimension first).
    """
    left, right, top, bottom = pad
    # Pad rows (dim 2) first, then columns (dim 3) of the padded result.
    rows = torch.cat([input[:, :, -top:], input, input[:, :, 0:bottom]], dim=2)
    return torch.cat([rows[:, :, :, -left:], rows, rows[:, :, :, 0:right]], dim=3)
3936
3937
def padding3d_circular(input, pad):
    r"""Circularly pad the last three dimensions of a 5-d tensor.

    ``pad`` is (w_left, w_right, h_top, h_bottom, d_front, d_back),
    following the F.pad convention (innermost dimension first).
    """
    w0, w1, h0, h1, d0, d1 = pad
    # Pad depth (dim 2), then height (dim 3), then width (dim 4),
    # each time wrapping slices of the already-padded tensor.
    out = torch.cat([input[:, :, -d0:], input, input[:, :, 0:d1]], dim=2)
    out = torch.cat([out[:, :, :, -h0:], out, out[:, :, :, 0:h1]], dim=3)
    return torch.cat([out[:, :, :, :, -w0:], out, out[:, :, :, :, 0:w1]], dim=4)
3982
3983
# Map from loss-module name to its pure-Python reference implementation
# (defined earlier in this file); test specs look these up, e.g.
# loss_reference_fns["NLLLossNd"].
loss_reference_fns = {
    "KLDivLoss": kldivloss_reference,
    "NLLLoss": nllloss_reference,
    "NLLLossNd": nlllossNd_reference,
    "SmoothL1Loss": smoothl1loss_reference,
    "MultiLabelMarginLoss": multilabelmarginloss_reference,
    "HingeEmbeddingLoss": hingeembeddingloss_reference,
    "SoftMarginLoss": softmarginloss_reference,
    "MultiMarginLoss": multimarginloss_reference,
    "CosineEmbeddingLoss": cosineembeddingloss_reference,
    "TripletMarginLoss": tripletmarginloss_reference,
    "MarginRankingLoss": marginrankingloss_reference,
    "CTCLoss": ctcloss_reference,
}
3998
3999
# Declarative specs for criterion (loss) module tests.  Keys used by the
# harness:
#   module_name                       -- torch.nn loss class under test;
#   constructor_args(_fn) / cpp_constructor_args / legacy_constructor_args
#                                     -- how to build the module (Python and
#                                        C++ API-parity sides);
#   input_size|input_fn, target_size|target_fn
#                                     -- forward data generators;
#   reference_fn(i, t, m)             -- pure-Python expected-loss computation;
#   desc                              -- variant suffix for the test name;
#   check_*                           -- optional extra checks (sum reduction,
#                                        gradgrad, bfloat16, ...).
criterion_tests = [
    dict(
        module_name="L1Loss",
        input_size=(2, 3, 4),
        target_size=(2, 3, 4),
        reference_fn=lambda i, t, _: 1.0
        / i.numel()
        * sum((a - b).abs().sum() for a, b in zip(i, t)),
    ),
    dict(
        module_name="NLLLoss",
        input_fn=lambda: torch.rand(15, 10).log(),
        target_fn=lambda: torch.Tensor(15).uniform_().mul(10).floor().long(),
        reference_fn=lambda i, t, m: nllloss_reference(
            i, t, reduction=get_reduction(m)
        ),
        check_sum_reduction=True,
        check_bfloat16=TEST_WITH_ROCM,
    ),
    dict(
        module_name="NLLLoss",
        constructor_args=(None, None, 2),
        cpp_constructor_args="torch::nn::NLLLossOptions().weight({}).ignore_index(2)",
        input_fn=lambda: torch.rand(15, 10).log(),
        target_fn=lambda: torch.Tensor(15).uniform_().mul(10).floor().long(),
        reference_fn=lambda i, t, _: nllloss_reference(i, t, ignore_index=2),
        desc="ignore_index",
        check_bfloat16=TEST_WITH_ROCM,
    ),
    dict(
        module_name="NLLLoss",
        constructor_args_fn=lambda: (torch.rand(10),),
        cpp_constructor_args="torch::nn::NLLLossOptions().weight(torch::rand(10))",
        input_fn=lambda: torch.rand(15, 10).add(1e-2).log(),
        target_fn=lambda: torch.Tensor(15).uniform_().mul(10).floor().long(),
        reference_fn=lambda i, t, m: nllloss_reference(i, t, weight=get_weight(m)),
        desc="weights",
        check_bfloat16=TEST_WITH_ROCM,
    ),
    dict(
        module_name="NLLLoss",
        constructor_args_fn=lambda: (torch.rand(10), None, 2),
        cpp_constructor_args="torch::nn::NLLLossOptions().weight(torch::rand(10)).ignore_index(2)",
        input_fn=lambda: torch.rand(15, 10).add(1e-2).log(),
        target_fn=lambda: torch.Tensor(15).uniform_().mul(10).floor().long(),
        reference_fn=lambda i, t, m: nllloss_reference(
            i, t, weight=get_weight(m), ignore_index=2
        ),
        desc="weights_ignore_index",
        check_bfloat16=TEST_WITH_ROCM,
    ),
    dict(
        module_name="NLLLoss",
        constructor_args_fn=lambda: (torch.rand(10), None, -1),
        cpp_constructor_args="torch::nn::NLLLossOptions().weight(torch::rand(10)).ignore_index(-1)",
        input_fn=lambda: torch.rand(15, 10).add(1e-2).log(),
        # Targets land in [-1, 9]; the -1 entries are skipped via ignore_index.
        target_fn=lambda: torch.Tensor(15).uniform_().mul(10 + 1).floor().long() - 1,
        reference_fn=lambda i, t, m: nllloss_reference(
            i, t, weight=get_weight(m), ignore_index=-1
        ),
        desc="weights_ignore_index_neg",
        check_bfloat16=TEST_WITH_ROCM,
    ),
    dict(
        module_name="KLDivLoss",
        input_fn=lambda: torch.rand(10, 10).log(),
        target_fn=lambda: torch.rand(10, 10),
        reference_fn=lambda i, t, m: kldivloss_reference(i, t, get_reduction(m)),
        check_sum_reduction=True,
    ),
    dict(
        module_name="MSELoss",
        input_size=(2, 3, 4, 5),
        target_size=(2, 3, 4, 5),
        reference_fn=lambda i, t, m: (
            (i - t).abs().pow(2).sum()
            / (i.numel() if get_reduction(m) == "mean" else 1)
        ),
        check_sum_reduction=True,
    ),
    dict(
        module_name="BCELoss",
        input_fn=lambda: torch.rand(15, 10).clamp_(1e-2, 1 - 1e-2),
        target_fn=lambda: torch.randn(15, 10).gt(0).double(),
        # NOTE(review): the denominator tests truthiness of get_reduction(m)
        # rather than == "mean" as in sibling entries -- confirm intended.
        reference_fn=lambda i, t, m: -(t * i.log() + (1 - t) * (1 - i).log()).sum()
        / (i.numel() if get_reduction(m) else 1),
        check_gradgrad=False,
        check_bfloat16=TEST_WITH_ROCM,
    ),
    dict(
        module_name="BCELoss",
        constructor_args_fn=lambda: (torch.rand(10),),
        cpp_constructor_args="torch::nn::BCELossOptions().weight(torch::rand(10))",
        input_fn=lambda: torch.rand(15, 10).clamp_(1e-2, 1 - 1e-2),
        target_fn=lambda: torch.randn(15, 10).gt(0).double(),
        reference_fn=lambda i, t, m: -(
            (t * i.log() + (1 - t) * (1 - i).log()) * get_weight(m)
        ).sum()
        / (i.numel() if get_reduction(m) else 1),
        desc="weights",
        check_gradgrad=False,
        check_bfloat16=TEST_WITH_ROCM,
    ),
    dict(
        module_name="CrossEntropyLoss",
        input_size=(15, 10),
        target_fn=lambda: torch.Tensor(15).uniform_().mul(10).floor().long(),
    ),
    dict(
        module_name="CrossEntropyLoss",
        constructor_args_fn=lambda: (torch.rand(10),),
        cpp_constructor_args="torch::nn::CrossEntropyLossOptions().weight(torch::rand(10))",
        input_size=(15, 10),
        target_fn=lambda: torch.Tensor(15).uniform_().mul(10).floor().long(),
        desc="weights",
    ),
    dict(
        module_name="HingeEmbeddingLoss",
        input_size=(10,),
        # Targets are +/-1 labels.
        target_fn=lambda: torch.randn(10).gt(0).double().mul_(2).sub(1),
        reference_fn=lambda i, t, m: hingeembeddingloss_reference(
            i, t, reduction=get_reduction(m)
        ),
        check_sum_reduction=True,
    ),
    dict(
        module_name="HingeEmbeddingLoss",
        constructor_args=(0.5,),
        cpp_constructor_args="torch::nn::HingeEmbeddingLossOptions().margin(0.5)",
        input_size=(10,),
        target_fn=lambda: torch.randn(10).gt(0).double().mul_(2).sub(1),
        reference_fn=lambda i, t, m: hingeembeddingloss_reference(
            i, t, margin=0.5, reduction=get_reduction(m)
        ),
        desc="margin",
        check_sum_reduction=True,
    ),
    dict(
        module_name="MultiLabelMarginLoss",
        input_size=(10,),
        target_fn=lambda: torch.rand(10).mul(10).floor().long(),
        reference_fn=lambda i, t, m: multilabelmarginloss_reference(
            i, t, reduction=get_reduction(m)
        ),
        desc="1d",
        check_sum_reduction=True,
        check_gradgrad=False,
        check_bfloat16=TEST_WITH_ROCM,
    ),
    dict(
        module_name="MultiLabelMarginLoss",
        input_size=(5, 10),
        target_fn=lambda: torch.rand(5, 10).mul(10).floor().long(),
        reference_fn=lambda i, t, m: multilabelmarginloss_reference(
            i, t, reduction=get_reduction(m)
        ),
        check_sum_reduction=True,
        check_gradgrad=False,
        check_bfloat16=TEST_WITH_ROCM,
    ),
    dict(
        module_name="MultiLabelSoftMarginLoss",
        input_size=(5, 10),
        target_fn=lambda: torch.rand(5, 10).mul(2).floor(),
        reference_fn=lambda i, t, m: -(
            t * i.sigmoid().log() + (1 - t) * (-i).sigmoid().log()
        ).sum()
        / i.numel(),
        check_gradgrad=False,
    ),
    dict(
        module_name="MultiMarginLoss",
        input_size=(5, 10),
        target_fn=lambda: torch.rand(5).mul(8).floor().long(),
        reference_fn=lambda i, t, m: multimarginloss_reference(
            i, t, reduction=get_reduction(m)
        ),
        check_sum_reduction=True,
        check_gradgrad=False,
    ),
    dict(
        module_name="MultiMarginLoss",
        input_size=(10,),
        target_fn=lambda: torch.rand(1).mul(8).floor().long(),
        reference_fn=lambda i, t, m: multimarginloss_reference(
            i, t, reduction=get_reduction(m)
        ),
        desc="1d",
        check_sum_reduction=True,
        check_gradgrad=False,
    ),
    dict(
        module_name="MultiMarginLoss",
        constructor_args=(2,),
        cpp_constructor_args="torch::nn::MultiMarginLossOptions().p(2)",
        input_fn=lambda: torch.rand(5, 10).clamp_(1e-2, 1 - 1e-2),
        target_fn=lambda: torch.rand(5).mul(8).floor().long(),
        reference_fn=lambda i, t, m: multimarginloss_reference(
            i, t, p=2, reduction=get_reduction(m)
        ),
        desc="p",
        check_sum_reduction=True,
        check_gradgrad=False,
    ),
    dict(
        module_name="MultiMarginLoss",
        constructor_args=(1, 0.5),
        cpp_constructor_args="torch::nn::MultiMarginLossOptions().p(1).margin(0.5)",
        legacy_constructor_args=(1, None, 0.5),
        input_size=(5, 10),
        target_fn=lambda: torch.rand(5).mul(8).floor().long(),
        reference_fn=lambda i, t, m: multimarginloss_reference(
            i, t, margin=0.5, reduction=get_reduction(m)
        ),
        desc="margin",
        check_sum_reduction=True,
        check_gradgrad=False,
    ),
    dict(
        module_name="MultiMarginLoss",
        constructor_args=(1, 1.0, torch.rand(10)),
        cpp_constructor_args="torch::nn::MultiMarginLossOptions().p(1).margin(1.).weight(torch::rand(10))",
        legacy_constructor_args=(1, torch.rand(10)),
        input_size=(5, 10),
        target_fn=lambda: torch.rand(5).mul(8).floor().long(),
        reference_fn=lambda i, t, m: multimarginloss_reference(
            i, t, weight=get_weight(m), reduction=get_reduction(m)
        ),
        desc="weights",
        check_sum_reduction=True,
        check_gradgrad=False,
    ),
    dict(
        module_name="SmoothL1Loss",
        input_size=(5, 10),
        target_size=(5, 10),
        check_sum_reduction=True,
        reference_fn=lambda i, t, m: smoothl1loss_reference(
            i, t, reduction=get_reduction(m)
        ),
    ),
    dict(
        module_name="SoftMarginLoss",
        input_size=(5, 5),
        target_fn=lambda: torch.randn(5, 5).sign(),
        reference_fn=lambda i, t, m: softmarginloss_reference(
            i, t, reduction=get_reduction(m)
        ),
        check_sum_reduction=True,
    ),
    dict(
        module_name="CosineEmbeddingLoss",
        input_fn=lambda: (torch.rand(15, 10), torch.rand(15, 10)),
        target_fn=lambda: torch.randn(15).sign(),
        reference_fn=lambda i, t, m: cosineembeddingloss_reference(
            i[0], i[1], t, reduction=get_reduction(m)
        ),
        check_sum_reduction=True,
    ),
    dict(
        module_name="CosineEmbeddingLoss",
        constructor_args=(0.7,),
        cpp_constructor_args="torch::nn::CosineEmbeddingLossOptions().margin(0.7)",
        input_fn=lambda: (torch.rand(15, 10), torch.rand(15, 10)),
        target_fn=lambda: torch.randn(15).sign(),
        reference_fn=lambda i, t, m: cosineembeddingloss_reference(
            i[0], i[1], t, margin=0.7, reduction=get_reduction(m)
        ),
        desc="margin",
        check_sum_reduction=True,
    ),
    dict(
        module_name="MarginRankingLoss",
        input_fn=lambda: (torch.randn(50).mul(10), torch.randn(50).mul(10)),
        target_fn=lambda: torch.randn(50).sign(),
        reference_fn=lambda i, t, m: marginrankingloss_reference(
            i[0], i[1], t, reduction=get_reduction(m)
        ),
        check_sum_reduction=True,
    ),
    dict(
        module_name="MarginRankingLoss",
        constructor_args=(0.5,),
        cpp_constructor_args="torch::nn::MarginRankingLossOptions().margin(0.5)",
        input_fn=lambda: (torch.randn(50).mul(10), torch.randn(50).mul(10)),
        target_fn=lambda: torch.randn(50).sign(),
        reference_fn=lambda i, t, m: marginrankingloss_reference(
            i[0], i[1], t, margin=0.5, reduction=get_reduction(m)
        ),
        desc="margin",
        check_sum_reduction=True,
    ),
]
4293
# Additional criterion specs: BCEWithLogits, N-d NLLLoss, PoissonNLL, scalar
# (0-dim) variants, and CTCLoss.  Same schema as `criterion_tests` above,
# plus:
#   extra_args       -- positional args appended to forward (CTC lengths);
#   convert_target   -- whether the harness may cast the target dtype;
#   test_cpp_api_parity -- disable C++ parity checks where unsupported.
new_criterion_tests = [
    dict(
        module_name="BCEWithLogitsLoss",
        input_fn=lambda: torch.rand(15, 10).clamp_(1e-2, 1 - 1e-2),
        target_fn=lambda: torch.randn(15, 10).gt(0).double(),
    ),
    dict(
        module_name="BCEWithLogitsLoss",
        constructor_args=(torch.rand(10),),
        cpp_constructor_args="torch::nn::BCEWithLogitsLossOptions().weight(torch::rand(10))",
        input_fn=lambda: torch.rand(15, 10).clamp_(1e-2, 1 - 1e-2),
        target_fn=lambda: torch.randn(15, 10).gt(0).double(),
        desc="weights",
    ),
    dict(
        module_name="BCEWithLogitsLoss",
        constructor_args=(torch.rand(()),),
        cpp_constructor_args="torch::nn::BCEWithLogitsLossOptions().weight(torch::rand({}))",
        input_fn=lambda: torch.rand(()).clamp_(1e-2, 1 - 1e-2),
        target_fn=lambda: torch.randn(()).gt(0).double(),
        desc="scalar_weights",
    ),
    dict(
        module_name="NLLLoss",
        input_size=(2, 3, 5, 5),
        target_fn=lambda: torch.rand(2, 5, 5).mul(3).floor().long(),
        reference_fn=lambda i, t, m: loss_reference_fns["NLLLossNd"](
            i, t, reduction=get_reduction(m)
        ),
        check_sum_reduction=True,
        desc="2d",
        check_bfloat16=TEST_WITH_ROCM,
    ),
    dict(
        module_name="NLLLoss",
        constructor_args_fn=lambda: (torch.rand(3),),
        cpp_constructor_args="torch::nn::NLLLossOptions().weight(torch::rand(3))",
        input_size=(2, 3, 5, 5),
        # NOTE(review): uses `target=` (one fixed tensor built at import time)
        # rather than `target_fn` like the sibling entries -- confirm intended.
        target=torch.rand(2, 5, 5).mul(3).floor().long(),
        reference_fn=lambda i, t, m: loss_reference_fns["NLLLossNd"](
            i, t, weight=get_weight(m)
        ),
        desc="2d_weights",
        check_bfloat16=TEST_WITH_ROCM,
    ),
    dict(
        module_name="NLLLoss",
        constructor_args=(None, None, 1),
        cpp_constructor_args="torch::nn::NLLLossOptions().weight({}).ignore_index(1)",
        input_size=(2, 3, 5, 5),
        target_fn=lambda: torch.rand(2, 5, 5).mul(3).floor().long(),
        reference_fn=lambda i, t, m: loss_reference_fns["NLLLossNd"](
            i, t, ignore_index=1
        ),
        desc="2d_ignore_index",
        check_bfloat16=TEST_WITH_ROCM,
    ),
    dict(
        module_name="NLLLoss",
        input_size=(2, 3, 5, 5, 2, 2),
        target_fn=lambda: torch.rand(2, 5, 5, 2, 2).mul(3).floor().long(),
        reference_fn=lambda i, t, m: loss_reference_fns["NLLLossNd"](
            i, t, reduction=get_reduction(m)
        ),
        check_sum_reduction=True,
        desc="higher_dim",
        check_bfloat16=TEST_WITH_ROCM,
    ),
    dict(
        module_name="NLLLoss",
        input_size=(2, 3, 5),
        target_fn=lambda: torch.rand(2, 5).mul(3).floor().long(),
        reference_fn=lambda i, t, m: loss_reference_fns["NLLLossNd"](
            i, t, reduction=get_reduction(m)
        ),
        check_sum_reduction=True,
        desc="dim_is_3",
        check_bfloat16=TEST_WITH_ROCM,
    ),
    dict(
        module_name="PoissonNLLLoss",  # Default is log_input=True, full=False
        input_size=(2, 3, 4, 5),
        target_fn=lambda: torch.randn(2, 3, 4, 5).floor_().abs_(),
        reference_fn=lambda i, t, _: (i.exp() - t.mul(i)).mean(),
        desc="no_full_loss",
    ),
    dict(
        module_name="PoissonNLLLoss",
        constructor_args=(False, False),  # log_input=False, full=False
        cpp_constructor_args="torch::nn::PoissonNLLLossOptions().log_input(false).full(false)",
        input_fn=lambda: torch.randn(2, 3, 4, 5).abs_().add_(0.001),
        target_fn=lambda: torch.randn(2, 3, 4, 5).floor_().abs_(),
        reference_fn=lambda i, t, _: (i - t.mul((i + 1e-8).log())).mean(),
        desc="no_full_loss_no_log_input",
    ),
    dict(
        module_name="PoissonNLLLoss",
        constructor_args=(True, True),  # log_input=True, full=True
        cpp_constructor_args="torch::nn::PoissonNLLLossOptions().log_input(true).full(true)",
        input_size=(2, 3, 4, 5),
        target_fn=lambda: torch.randn(2, 3, 4, 5).floor_().abs_(),
        # full=True adds the Stirling approximation term for targets > 1.
        reference_fn=lambda i, t, _: (
            i.exp()
            - t.mul(i)
            + (t.mul(t.log()) - t + 0.5 * (2.0 * pi * t).log()).masked_fill(t <= 1, 0)
        ).mean(),
        desc="full_loss",
    ),
    dict(
        module_name="PoissonNLLLoss",
        constructor_args=(False, True),  # log_input=False, full=True
        cpp_constructor_args="torch::nn::PoissonNLLLossOptions().log_input(false).full(true)",
        input_fn=lambda: torch.randn(2, 3, 4, 5).abs_().add_(0.001),
        target_fn=lambda: torch.randn(2, 3, 4, 5).floor_().abs_(),
        reference_fn=lambda i, t, _: (
            i
            - t.mul((i + 1e-8).log())
            + (t.mul(t.log()) - t + 0.5 * (2.0 * pi * t).log()).masked_fill(t <= 1, 0)
        ).mean(),
        desc="full_loss_no_log_input",
    ),
    dict(
        module_name="L1Loss",
        input_size=(),
        target_size=(),
        reference_fn=lambda i, t, _: 1.0 / i.numel() * (i - t).abs().sum(),
        desc="scalar",
    ),
    dict(
        module_name="KLDivLoss",
        input_fn=lambda: torch.rand(()).log(),
        target_fn=lambda: torch.rand(()),
        reference_fn=lambda i, t, m: kldivloss_reference(i, t, get_reduction(m)),
        check_sum_reduction=True,
        desc="scalar",
    ),
    dict(
        module_name="MSELoss",
        input_size=(),
        target_size=(),
        reference_fn=lambda i, t, m: (
            (i - t).abs().pow(2).sum()
            / (i.numel() if get_reduction(m) == "mean" else 1)
        ),
        check_sum_reduction=True,
        desc="scalar",
        check_bfloat16=TEST_WITH_ROCM,
    ),
    dict(
        module_name="MSELoss",
        input_fn=lambda: torch.ones(5, 68, 64, 64, dtype=torch.float) / 10,
        target_fn=lambda: torch.zeros(5, 68, 64, 64, dtype=torch.float),
        reference_fn=lambda i, t, m: (
            (i - t).abs().pow(2).sum()
            / (i.numel() if get_reduction(m) == "mean" else 1)
        ),
        check_forward_only=True,
        desc="prec",
        check_bfloat16=TEST_WITH_ROCM,
    ),
    dict(
        module_name="BCELoss",
        constructor_args_fn=lambda: (torch.rand(()),),
        cpp_constructor_args="torch::nn::BCELossOptions().weight(torch::rand({}))",
        input_fn=lambda: torch.rand(()).clamp_(1e-2, 1 - 1e-2),
        target_fn=lambda: torch.rand(()).gt(0).double(),
        reference_fn=lambda i, t, m: -(
            (t * i.log() + (1 - t) * (1 - i).log()) * get_weight(m)
        ).sum()
        / (i.numel() if get_reduction(m) == "mean" else 1),
        desc="scalar_weights",
        check_gradgrad=False,
        check_bfloat16=TEST_WITH_ROCM,
    ),
    dict(
        module_name="HingeEmbeddingLoss",
        constructor_args=(0.5,),
        cpp_constructor_args="torch::nn::HingeEmbeddingLossOptions().margin(0.5)",
        input_size=(),
        target_fn=lambda: torch.randn(()).gt(0).double().mul_(2).sub(1),
        desc="scalar_margin",
        check_sum_reduction=True,
    ),
    dict(
        module_name="SmoothL1Loss",
        input_size=(),
        target_size=(),
        check_sum_reduction=True,
        reference_fn=lambda i, t, m: smoothl1loss_reference(
            i, t, reduction=get_reduction(m)
        ),
        desc="scalar",
    ),
    dict(
        module_name="MultiLabelSoftMarginLoss",
        constructor_args=(torch.rand(10),),
        cpp_constructor_args="torch::nn::MultiLabelSoftMarginLossOptions().weight(torch::rand(10))",
        input_fn=lambda: torch.randn(5, 10),
        target_fn=lambda: torch.rand(5, 10).mul(2).floor(),
        reference_fn=lambda i, t, m: -(
            (t * i.sigmoid().log() + (1 - t) * (-i).sigmoid().log()) * get_weight(m)
        ).sum()
        / (
            i.numel()
            if get_reduction(m) == "mean"
            else i.size(1)
            if get_reduction(m) == "sum"
            else 1
        ),
        desc="weights",
        check_sum_reduction=True,
        check_gradgrad=False,
    ),
    dict(
        module_name="CTCLoss",
        constructor_args=(14,),  # blank=14
        extra_args=([50, 50, 50], [30, 25, 20]),  # input_lengths, target_lengths
        input_fn=lambda: torch.randn(50, 3, 15).log_softmax(2),
        target_fn=lambda: torch.randint(0, 14, (3, 30), dtype=torch.long),
        reference_fn=lambda i, t, il, tl, m: ctcloss_reference(
            i, t, il, tl, blank=14, reduction=get_reduction(m)
        ),
        desc="lengths_intlists",
        check_sum_reduction=True,
        check_gradgrad=False,
        check_half=False,
        # `CTCLoss` in C++ frontend doesn't accept integer list for `input_lengths` or `target_lengths`
        test_cpp_api_parity=False,
    ),
    dict(
        module_name="CTCLoss",
        constructor_args=(14,),  # blank=14
        cpp_constructor_args="torch::nn::CTCLossOptions().blank(14)",
        extra_args=(
            torch.tensor([50, 50, 50]),
            torch.tensor([30, 25, 20]),
        ),  # input_lengths, target_lengths
        input_fn=lambda: torch.randn(50, 3, 15).log_softmax(2),
        target_fn=lambda: torch.randint(0, 14, (3, 30), dtype=torch.long),
        reference_fn=lambda i, t, il, tl, m: ctcloss_reference(
            i, t, il, tl, blank=14, reduction=get_reduction(m)
        ),
        desc="lengths_tensors",
        check_sum_reduction=True,
        check_gradgrad=False,
        check_half=False,
    ),
    # Test is flaky
    # See https://github.com/pytorch/pytorch/issues/29380.
    # dict(
    #     module_name='CTCLoss',
    #     desc='1d_target',
    #     constructor_args=(14,), # blank=14
    #     extra_args=([50, 50, 50], [30, 25, 20]), # input_lengths, target_lengths
    #     input_fn=lambda: torch.randn(50, 3, 15).log_softmax(2),
    #     target_fn=lambda: torch.randint(0, 14, (3, 30), dtype=torch.long),
    #     reference_fn=lambda i, t, il, tl, m:
    #     ctcloss_reference(i, t, il, tl, blank=14, reduction=get_reduction(m)),
    #     check_sum_reduction=True,
    #     check_gradgrad=False,
    #     check_half=False,
    # ),
    dict(
        module_name="CTCLoss",
        desc="2d_int_target_lengths_intlists",
        constructor_args=(0,),  # blank=0
        extra_args=([50, 50, 50], [30, 25, 20]),  # input_lengths, target_lengths
        input_fn=lambda: torch.randn(50, 3, 15).log_softmax(2),
        target_fn=lambda: torch.randint(1, 15, (3, 30), dtype=torch.int),
        reference_fn=lambda i, t, il, tl, m: ctcloss_reference(
            i, t, il, tl, blank=0, reduction=get_reduction(m)
        ),
        check_sum_reduction=True,
        check_gradgrad=False,
        check_half=False,
        convert_target=False,
        # `CTCLoss` in C++ frontend doesn't accept integer list for `input_lengths` or `target_lengths`
        test_cpp_api_parity=False,
    ),
    dict(
        module_name="CTCLoss",
        desc="2d_int_target_lengths_tensors",
        constructor_args=(0,),  # blank=0
        cpp_constructor_args="torch::nn::CTCLossOptions().blank(0)",
        extra_args=(
            torch.tensor([50, 50, 50]),
            torch.tensor([30, 25, 20]),
        ),  # input_lengths, target_lengths
        input_fn=lambda: torch.randn(50, 3, 15).log_softmax(2),
        target_fn=lambda: torch.randint(1, 15, (3, 30), dtype=torch.int),
        reference_fn=lambda i, t, il, tl, m: ctcloss_reference(
            i, t, il, tl, blank=0, reduction=get_reduction(m)
        ),
        check_sum_reduction=True,
        check_gradgrad=False,
        check_half=False,
        convert_target=False,
    ),
    dict(
        module_name="CTCLoss",
        desc="2d_lengths_tensors",
        constructor_args=(0,),  # blank=0
        cpp_constructor_args="torch::nn::CTCLossOptions().blank(0)",
        extra_args=(
            torch.tensor([50, 50, 50]),
            torch.tensor([30, 25, 20]),
        ),  # input_lengths, target_lengths
        input_fn=lambda: torch.randn(50, 3, 15).log_softmax(2),
        target_fn=lambda: torch.randint(1, 15, (3, 30), dtype=torch.int),
        reference_fn=lambda i, t, il, tl, m: ctcloss_reference(
            i, t, il, tl, blank=0, reduction=get_reduction(m)
        ),
        check_sum_reduction=True,
        check_gradgrad=False,
        check_half=False,
        convert_target=False,
    ),
]
4612
4613
class NNTestCase(TestCase):
    """TestCase base with helpers that compare analytical jacobians (computed
    via backward) against numerical ones (central finite differences).

    NOTE(review): relies on hooks not visible in this chunk -- ``_forward``,
    ``_backward``, ``_forward_criterion``, ``_backward_criterion``,
    ``_get_parameters`` and ``_zero_grad_parameters`` -- presumably supplied
    by concrete subclasses; confirm against the test classes using this base.
    """

    def _jacobian(self, input, num_out):
        # Allocate zeroed jacobian storage mirroring the (possibly nested)
        # structure of `input`: one (n_elements, num_out) matrix per tensor.
        if isinstance(input, tuple):
            return tuple(self._jacobian(elem, num_out) for elem in input)
        elif isinstance(input, list):
            return [self._jacobian(elem, num_out) for elem in input]
        else:
            return torch.zeros(input.nelement(), num_out)

    def _flatten_tensors(self, x):
        # Flatten a tensor to 1-D (densifying sparse tensors first), or
        # recurse over containers of tensors.
        if isinstance(x, torch.Tensor):
            if x.is_sparse:
                return x.to_dense().view(-1)
            else:
                return x.view(-1)
        else:
            return tuple(self._flatten_tensors(a) for a in x)

    def _zero_grad_input(self, input):
        # Zero (and detach) any accumulated gradient on each input tensor.
        if isinstance(input, torch.Tensor):
            if input.requires_grad and input.grad is not None:
                input.grad.zero_()
                input.grad.detach_()
        else:
            for i in input:
                self._zero_grad_input(i)

    def _analytical_jacobian(
        self, module, input, jacobian_input=True, jacobian_parameters=True
    ):
        """Build jacobians column by column by backpropagating a one-hot
        grad_output for each output element.

        Returns a tuple of (input jacobian, parameter jacobian), including
        only the parts requested via the flags.
        """
        output = self._forward(module, input)
        output_size = output.nelement()

        if jacobian_input:
            jacobian_inp = self._jacobian(input, output_size)
            flat_jacobian_input = list(iter_tensors(jacobian_inp))

        if jacobian_parameters:
            num_param = sum(p.numel() for p in self._get_parameters(module)[0])
            jacobian_param = torch.zeros(num_param, output_size)

        for i in range(output_size):
            param, d_param = self._get_parameters(module)
            # make non grad zeros
            d_param = [
                torch.zeros_like(p) if d is None else d
                for (p, d) in zip(param, d_param)
            ]
            d_out = torch.zeros_like(output)
            flat_d_out = d_out.view(-1)
            # One-hot grad_output selects the i-th output element.
            flat_d_out[i] = 1
            if jacobian_parameters:
                self._zero_grad_parameters(module)
            # Tensors will accumulate gradient from multiple steps
            if jacobian_input:
                self._zero_grad_input(input)
            d_input = self._backward(module, input, output, d_out)
            if jacobian_input:
                for jacobian_x, d_x in zip(flat_jacobian_input, iter_tensors(d_input)):
                    jacobian_x[:, i] = d_x.contiguous().view(-1)
            if jacobian_parameters:
                jacobian_param[:, i] = torch.cat(self._flatten_tensors(d_param), 0)

        res = tuple()
        if jacobian_input:
            res += (jacobian_inp,)
        if jacobian_parameters:
            res += (jacobian_param,)

        return res

    def _numerical_jacobian(
        self, module, input, jacobian_input=True, jacobian_parameters=True
    ):
        """Finite-difference jacobians w.r.t. the input and/or parameters."""

        def fw(input):
            return self._forward(module, input).detach()

        res = tuple()
        if jacobian_input:
            res += (get_numerical_jacobian(fw, input, eps=1e-6),)
        if jacobian_parameters:
            param, _ = self._get_parameters(module)
            res += (
                torch.cat(
                    [get_numerical_jacobian(fw, input, p, eps=1e-6) for p in param], 0
                ),
            )
        return res

    def check_jacobian(self, module, input, jacobian_input=True):
        """Assert analytical and numerical jacobians agree within PRECISION."""
        jacobian_parameters = bool(self._get_parameters(module)[0])
        analytical = self._analytical_jacobian(
            module, input, jacobian_input, jacobian_parameters
        )
        numerical = self._numerical_jacobian(
            module, input, jacobian_input, jacobian_parameters
        )
        analytical_t = list(iter_tensors(analytical))
        numerical_t = list(iter_tensors(numerical))

        # TODO: compare structure
        if input.numel() != 0:
            self.assertLessEqual(
                max(
                    a.add(n, alpha=-1).abs().max()
                    for a, n in zip(analytical_t, numerical_t)
                ),
                PRECISION,
            )

    def check_criterion_jacobian(self, criterion, input, target):
        """Central-difference check of a criterion's gradient w.r.t. input."""
        eps = 1e-6
        self._forward_criterion(criterion, input, target)
        analytical_d_x = self._backward_criterion(criterion, input, target)
        numerical_d_x = deepcopy(analytical_d_x)

        input_t = iter_tensors(input)
        numerical_t = iter_tensors(numerical_d_x)
        for x, d_x in zip(input_t, numerical_t):
            x = x.view(-1).data
            d_x = d_x.view(-1).data
            for i in range(x.nelement()):
                original = x[i].item()
                x[i] = original + eps
                fx1 = self._forward_criterion(criterion, input, target)
                x[i] = original - eps
                fx2 = self._forward_criterion(criterion, input, target)
                deriv = (fx1 - fx2) / (2.0 * eps)
                d_x[i] = float(deriv)
                # Restore the perturbed entry before moving on.
                x[i] = original

        # TODO: check structure
        analytical_t = list(iter_tensors(analytical_d_x))
        numerical_t = list(iter_tensors(numerical_d_x))

        self.assertLessEqual(
            max(
                a.add(n, alpha=-1).abs().max()
                for a, n in zip(analytical_t, numerical_t)
            ),
            PRECISION,
        )
4756
4757
class TestBase(object):
    """Declarative description of a single NN test.

    A test is defined by a ``constructor`` (typically an ``nn.Module`` class)
    plus three arguments -- ``constructor_args``, ``input`` and
    ``extra_args`` -- each of which may be supplied directly (``name``),
    lazily (``name_fn``), or as a tensor size to fill with random data
    (``name_size``).  Values are resolved lazily and cached.  Subclasses
    implement ``__call__`` to actually run the test.
    """

    # Arguments that must be specified in one of the three forms above.
    _required_arg_names = {"constructor_args", "input", "extra_args"}

    def __init__(
        self, constructor, desc="", reference_fn=None, fullname=None, **kwargs
    ):
        """
        Args:
            constructor: callable that builds the object under test.
            desc: optional suffix appended to the generated test name.
            reference_fn: optional reference implementation to compare against.
            fullname: overrides the generated test name entirely.
            **kwargs: the test spec. ``input`` must be provided (directly, via
                ``input_fn``, or via ``input_size``); ``constructor_args`` and
                ``extra_args`` default to empty tuples.

        Raises:
            ValueError: if ``input`` is not specified in any form.
        """
        self.desc = desc
        self.fullname = fullname
        self.constructor = constructor
        self.reference_fn = reference_fn
        for name in self._required_arg_names:
            if (
                name not in kwargs
                and name + "_fn" not in kwargs
                and name + "_size" not in kwargs
            ):
                if name in {"constructor_args", "extra_args"}:
                    kwargs[name] = tuple()
                else:
                    # Fixed grammar in the message: "it's" -> "its".
                    raise ValueError(
                        "{}: Specify {} by a value, a function to generate it, or its size!".format(
                            self.get_name(), name
                        )
                    )
        self._extra_kwargs = kwargs
        self._arg_cache = {}

    def get_name(self):
        """Return the unittest-style name ("test_<ctor>[_<desc>]") for this spec."""
        if self.fullname is not None:
            return "test_" + self.fullname

        test_name = "test_" + self.constructor.__name__
        if self.desc:
            test_name += "_" + self.desc
        return test_name

    def _unpack(self, value):
        # Recursively rebuild containers; tensors and scalars pass through.
        if isinstance(value, torch.Tensor):
            return value
        elif is_iterable(value):
            return type(value)(self._unpack(v) for v in value)
        else:
            return value

    @property
    def constructor_args(self):
        return self._get_arg("constructor_args", True)

    @property
    def extra_args(self):
        return self._get_arg("extra_args", True)

    def _get_arg(self, name, unpack):
        """Resolve argument `name` from its direct/_fn/_size form, caching the
        result; `unpack` additionally copies container structure."""
        assert name in self._required_arg_names

        if name not in self._arg_cache:
            fn_name = name + "_fn"
            size_name = name + "_size"

            if name in self._extra_kwargs:
                self._arg_cache[name] = self._extra_kwargs[name]
            elif fn_name in self._extra_kwargs:
                self._arg_cache[name] = self._extra_kwargs[fn_name]()
            else:
                assert (
                    size_name in self._extra_kwargs
                ), "Missing `{}`, `{}` or `{}` for {}".format(
                    name, size_name, fn_name, self.get_name()
                )

                def map_tensor_sizes(sizes):
                    # Sizes may be nested lists; an actual tensor is used
                    # verbatim (as double), anything else is treated as a
                    # shape for torch.randn.
                    if isinstance(sizes, list):
                        return [map_tensor_sizes(s) for s in sizes]
                    elif isinstance(sizes, torch.Tensor):
                        return sizes.double()
                    else:
                        return torch.randn(sizes)

                self._arg_cache[name] = map_tensor_sizes(self._extra_kwargs[size_name])

        return self._unpack(self._arg_cache[name]) if unpack else self._arg_cache[name]

    def _get_input(self, unpack=True):
        return self._get_arg("input", unpack)

    def __call__(self, test_case):
        raise NotImplementedError
4845
4846
4847class ModuleTest(TestBase):
4848def __init__(self, *args, **kwargs):
4849super(ModuleTest, self).__init__(*args, **kwargs)
4850self.jacobian_input = kwargs.get("jacobian_input", True)
4851self.should_test_cuda = kwargs.get("test_cuda", True)
4852self.should_test_pickle = kwargs.get("pickle", True)
4853self.check_gradgrad = kwargs.get("check_gradgrad", True)
4854self.FIXME_no_cuda_gradgrad_comparison = kwargs.get(
4855"FIXME_no_cuda_gradgrad_comparison", False
4856)
4857self.precision = kwargs.get("precision", 2e-4)
4858self.check_forward_only = kwargs.get("check_forward_only", False)
4859
    def __call__(self, test_case):
        """Run the full module test on the "xpu" device: forward vs.
        reference_fn (if given), noncontiguous input/grad checks, pickle
        round-trip, then the subclass-specific ``_do_test``."""
        module = self.constructor(*self.constructor_args).to("xpu")
        input = self._get_input()

        if self.reference_fn is not None:
            out = test_case._forward(module, input)
            ref_input = deepcopy(input)
            ref_module = deepcopy(module)
            expected_out = self.reference_fn(
                ref_input, test_case._get_parameters(module)[0], ref_module
            )
            test_case.assertEqual(out, expected_out)
        # These conv modules are treated as not supporting backward for
        # float64 input here, so the backward-exercising checks are skipped.
        unsupported_backward_modules = [
            "Conv1d",
            "Conv2d",
            "Conv3d",
            "ConvTranspose1d",
            "ConvTranspose2d",
            "ConvTranspose3d",
        ]
        if (
            module._get_name() in unsupported_backward_modules
            and input.dtype == torch.float64
        ):
            return
        if self.check_forward_only:
            return
        self.test_noncontig(test_case, module, input)

        if self.should_test_pickle:
            # TODO: do this with in-memory files as soon as torch.save will support it
            with TemporaryFile() as f:
                test_case._forward(module, input)
                torch.save(module, f)
                f.seek(0)
                module_copy = torch.load(f)
                # Saved/reloaded module must produce the same forward output.
                test_case.assertEqual(
                    test_case._forward(module, input),
                    test_case._forward(module_copy, input),
                )

        self._do_test(test_case, module, input)
4902
4903def noncontiguize(self, obj):
4904if isinstance(obj, list):
4905return [self.noncontiguize(o) for o in obj]
4906tensor = obj
4907ndim = tensor.dim()
4908# Always making only the last dimension noncontiguous is easy to hide
4909# bugs because .view(-1) will still work. So try to find a dim with size
4910# > 1 and make that non-contiguous, i.e., stack + select on the
4911# dimension directly after that.
4912dim = ndim
4913for d in range(ndim):
4914if tensor.size(d) > 1:
4915dim = d + 1
4916break
4917noncontig = (
4918torch.stack([torch.empty_like(tensor), tensor], dim).select(dim, 1).detach()
4919)
4920assert (
4921noncontig.numel() == 1
4922or noncontig.numel() == 0
4923or not noncontig.is_contiguous()
4924)
4925noncontig.requires_grad = tensor.requires_grad
4926return noncontig
4927
    def test_noncontig(self, test_case, module, input):
        """Check that forward/backward give identical results for contiguous
        and noncontiguous versions of the input and grad_output (all four
        combinations), with the RNG state frozen so results are comparable."""
        # check no scalars, can't make non-contig
        if isinstance(input, torch.Tensor) and input.dim() == 0:
            return
        if any(i.dim() == 0 for i in input if isinstance(i, torch.Tensor)):
            return

        # Reference run with the original (contiguous) input.
        test_case._zero_grad_parameters(module)
        test_case._zero_grad_input(input)
        with freeze_rng_state():
            output = test_case._forward(module, input)
            grad_output = output.new(output.shape).normal_()
            output = output.clone()
            d_input = deepcopy(test_case._backward(module, input, output, grad_output))
            d_param = deepcopy(test_case._get_parameters(module)[1])

        nc_input = self.noncontiguize(input)
        nc_grad_output = self.noncontiguize(grad_output)
        for contig_i, contig_g in product((True, False), repeat=2):
            i = input if contig_i else nc_input
            # Some ops, e.g., nn.Flatten, return gradient that shares
            # storage with the grad_output. Hence we copy here.
            go = deepcopy(grad_output if contig_g else nc_grad_output)
            test_case._zero_grad_parameters(module)
            test_case._zero_grad_input(i)
            with freeze_rng_state():
                out = test_case._forward(module, i)
                grad = test_case._backward(module, i, out, go)

                test_case.assertEqual(out, output)
                test_case.assertEqual(grad, d_input, 1e-4)
                test_case.assertEqual(test_case._get_parameters(module)[1], d_param)
4960
def test_cuda(self, test_case):
    """Compare CPU and CUDA results of the module under test.

    Copies CPU parameters into a float CUDA clone of the module, then
    checks that forward outputs, first-order gradients, and (optionally)
    double-backward gradients agree within ``self.precision``.
    Skips when CUDA is unavailable or the test opts out.
    """
    if not TEST_CUDA or not self.should_test_cuda:
        raise unittest.SkipTest("Excluded from CUDA tests")
    try:
        cpu_input = self._get_input()
        # Double inputs are compared against float CUDA tensors.
        type_map = {"torch.DoubleTensor": torch.cuda.FloatTensor}
        gpu_input = to_gpu(cpu_input, type_map=type_map)

        cpu_module = self.constructor(*self.constructor_args)
        gpu_module = self.constructor(*self.constructor_args).float().cuda()
        cpu_param = test_case._get_parameters(cpu_module)
        gpu_param = test_case._get_parameters(gpu_module)
        # Make the GPU module start from the exact same parameter values.
        for cpu_p, gpu_p in zip(cpu_param[0], gpu_param[0]):
            gpu_p.data.copy_(cpu_p)

        test_case._zero_grad_input(cpu_input)
        test_case._zero_grad_input(gpu_input)
        test_case._zero_grad_parameters(cpu_module)
        test_case._zero_grad_parameters(gpu_module)
        cpu_output = test_case._forward(cpu_module, cpu_input)
        gpu_output = test_case._forward(gpu_module, gpu_input)
        test_case.assertEqual(cpu_output, gpu_output, self.precision)

        # Run backwards on CPU and GPU and compare results
        for _ in range(5):
            cpu_gradOutput = cpu_output.clone().normal_()
            gpu_gradOutput = cpu_gradOutput.type("torch.cuda.FloatTensor")
            cpu_gradInput = test_case._backward(
                cpu_module, cpu_input, cpu_output, cpu_gradOutput
            )
            gpu_gradInput = test_case._backward(
                gpu_module, gpu_input, gpu_output, gpu_gradOutput
            )
            test_case.assertEqual(cpu_gradInput, gpu_gradInput, self.precision)
            for cpu_d_p, gpu_d_p in zip(cpu_param[1], gpu_param[1]):
                test_case.assertEqual(cpu_d_p, gpu_d_p, self.precision)

        # Run double-backwards on CPU and GPU and compare results
        if self.check_gradgrad and not self.FIXME_no_cuda_gradgrad_comparison:
            cpu_output = cpu_module(cpu_input)
            gpu_output = gpu_module(gpu_input)

            cpu_gradOutput = torch.randn_like(cpu_output, requires_grad=True)
            gpu_gradOutput = cpu_gradOutput.type_as(gpu_output).detach()
            gpu_gradOutput.requires_grad = True

            # create_graph=True so we can differentiate through the
            # first-order gradients below.
            cpu_gradInputs = torch.autograd.grad(
                cpu_output,
                (cpu_input,) + tuple(cpu_module.parameters()),
                cpu_gradOutput,
                create_graph=True,
            )
            gpu_gradInputs = torch.autograd.grad(
                gpu_output,
                (gpu_input,) + tuple(gpu_module.parameters()),
                gpu_gradOutput,
                create_graph=True,
            )

            for cpu_d_i, gpu_d_i in zip(cpu_gradInputs, gpu_gradInputs):
                test_case.assertEqual(cpu_d_i, gpu_d_i, self.precision)

            # We mix output into the second backwards computation so that
            # torch.autograd.grad doesn't complain that some inputs
            # are unreachable (which can happen if you differentiate
            # only on the gradient.
            cpu_gg = torch.autograd.grad(
                cpu_output.sum() + sum(map(lambda x: x.sum(), cpu_gradInputs)),
                (cpu_input, cpu_gradOutput) + tuple(cpu_module.parameters()),
                retain_graph=True,
            )
            gpu_gg = torch.autograd.grad(
                gpu_output.sum() + sum(map(lambda x: x.sum(), gpu_gradInputs)),
                (gpu_input, gpu_gradOutput) + tuple(gpu_module.parameters()),
                retain_graph=True,
            )

            # NOTE(review): re-asserts cpu_gradInput/gpu_gradInput from
            # the first-order loop above (already checked there) —
            # presumably intentional carry-over from upstream.
            test_case.assertEqual(cpu_gradInput, gpu_gradInput, self.precision)
            for cpu_d_p, gpu_d_p in zip(cpu_gg, gpu_gg):
                test_case.assertEqual(cpu_d_p, gpu_d_p, self.precision)

        self.test_noncontig(test_case, gpu_module, gpu_input)
    except NotImplementedError:
        pass
    # TODO: remove this after CUDA scatter_ is implemented
    except AttributeError as e:
        if (
            len(e.args) == 1
            and "'FloatTensor' object has no attribute 'scatter_'" in e.args[0]
        ):
            pass
        else:
            raise
5054
5055
class CriterionTest(TestBase):
    """Test harness for criterion (loss) modules.

    Extends TestBase by additionally requiring a "target" argument.
    The forward result is checked against ``reference_fn`` when given,
    and CPU results can be compared against CUDA.
    """

    _required_arg_names = TestBase._required_arg_names.union({"target"})

    def __init__(self, *args, **kwargs):
        super(CriterionTest, self).__init__(*args, **kwargs)
        # Both checks are on by default and opt-out via kwargs.
        self.should_test_cuda = kwargs.get("test_cuda", True)
        self.check_forward_only = kwargs.get("check_forward_only", True)

    def _get_target(self):
        """Return the criterion target, unpacked if necessary."""
        return self._get_arg("target", True)

    def __call__(self, test_case):
        criterion = self.constructor(*self.constructor_args)
        sample = self._get_input()

        # Check that these methods don't raise errors
        criterion.__repr__()
        str(criterion)

        target = self._get_target()

        if self.reference_fn is not None:
            actual = test_case._forward_criterion(
                criterion, sample, target, extra_args=self.extra_args
            )
            ref_args = (
                (deepcopy(sample), deepcopy(target)) + self.extra_args + (criterion,)
            )
            test_case.assertEqual(actual, self.reference_fn(*ref_args))

        if self.check_forward_only:
            return

        test_case.check_criterion_jacobian(criterion, sample, target)
        self._do_extra_tests(test_case, criterion, sample, target)

    def test_cuda(self, test_case):
        """Compare CPU vs CUDA forward/backward of the criterion."""
        if not TEST_CUDA or not self.should_test_cuda:
            raise unittest.SkipTest("Excluded from CUDA tests")
        try:
            # Double CPU inputs correspond to float CUDA tensors.
            type_map = {
                "torch.DoubleTensor": torch.cuda.FloatTensor,
            }

            cpu_input = self._get_input()
            gpu_input = to_gpu(cpu_input, type_map=type_map)

            cpu_target = self._get_target()
            gpu_target = to_gpu(cpu_target, type_map=type_map)

            cpu_module = self.constructor(*self.constructor_args)
            gpu_module = self.constructor(*self.constructor_args).float().cuda()

            cpu_output = test_case._forward_criterion(
                cpu_module, cpu_input, cpu_target
            )
            gpu_output = test_case._forward_criterion(
                gpu_module, gpu_input, gpu_target
            )
            test_case.assertEqual(cpu_output, gpu_output, 4e-4)

            gradOutput = torch.randn(())
            cpu_gradInput = test_case._backward_criterion(
                cpu_module, cpu_input, cpu_target, gradOutput
            )
            gpu_gradInput = test_case._backward_criterion(
                gpu_module, gpu_input, gpu_target, gradOutput
            )
            test_case.assertEqual(cpu_gradInput, gpu_gradInput, 4e-4)
        except NotImplementedError:
            # Criterion has no CUDA implementation; nothing to compare.
            pass

    def _do_extra_tests(self, test_case, module, input, target):
        # Hook for subclasses; the base harness adds nothing.
        pass
5124
5125
class InputVariableMixin(object):
    """Mixin that flags floating-point input tensors as requiring grad."""

    def _get_input(self):
        raw = TestBase._get_input(self, False)

        def _mark(item):
            # Non-tensors are containers: rebuild with the same type,
            # mapping recursively over the elements.
            if not isinstance(item, torch.Tensor):
                return type(item)(_mark(elem) for elem in item)
            # Only floating-point tensors can carry gradients.
            if item.is_floating_point():
                item.requires_grad = True
            return item

        return _mark(raw)
5139
5140
class NewModuleTest(InputVariableMixin, ModuleTest):
    """Module test with gradcheck, in-place, and device/dtype cast checks.

    Extends ModuleTest with gradient/gradgrad verification, an optional
    in-place-variant comparison, and checks that float()/double()/cuda()
    casts move parameters correctly.
    """

    def __init__(self, *args, **kwargs):
        super(NewModuleTest, self).__init__(*args, **kwargs)
        self.cudnn = kwargs.get("cudnn", False)  # also run with cuDNN disabled
        self.check_inplace = kwargs.get("check_inplace", False)
        self.check_gradgrad = kwargs.get("check_gradgrad", True)
        self.skip_double = kwargs.get("skip_double", False)

    def _do_test(self, test_case, module, input):
        """Run Jacobian, gradgrad, in-place, and cast checks on *module*."""
        test_case.check_jacobian(module, input, self.jacobian_input)

        if self.check_gradgrad:
            # could probably unify check_jacobian above with this.
            params = tuple(x for x in module.parameters())
            _assertGradAndGradgradChecks(
                test_case,
                lambda x, *args, **kw: test_case._forward(module, x),
                (input,) + params,
            )

        # check if module can be printed
        module.__repr__()

        if self.check_inplace:
            # check if the inplace variant of the module gives the same result
            # as the out-of-place

            module_ip = self.constructor(*self.constructor_args, inplace=True)

            # The out-of-place module must not mutate its input in place.
            input_version = input._version
            with freeze_rng_state():
                output = module(input)
            test_case.assertEqual(input._version, input_version)

            input_ip = deepcopy(input)
            if input.device.type == "xpu":
                input_ip.requires_grad = True
            input_ip_clone = input_ip.clone()
            with freeze_rng_state():
                output_ip = module_ip(input_ip_clone)
            # On CPU the in-place variant must actually mutate its input.
            if input.device == torch.device("cpu"):
                test_case.assertNotEqual(input_ip_clone._version, input_version)
            test_case.assertEqual(output, output_ip)
            grad = output.data.clone().normal_()
            # NOTE(review): relies on input.grad already existing from the
            # gradcheck above — TODO confirm.
            input.grad.data.zero_()
            output.backward(grad)
            output_ip.backward(grad)
            test_case.assertEqual(input.grad, input_ip.grad)

        if isinstance(input, torch.LongTensor) and TEST_CUDA:
            # check that cuda() moves module parameters to correct GPU device,
            # and that float() casts parameters correctly

            input = input.cuda()
            module.float().cuda()
            module(input)
            for p in module.parameters():
                test_case.assertIsInstance(p, torch.cuda.FloatTensor)
                test_case.assertEqual(p.get_device(), 0)

            if torch.cuda.device_count() > 1:
                input = input.cuda(1)
                module.cuda(1)
                with torch.cuda.device(1):
                    module(input)
                for p in module.parameters():
                    test_case.assertIsInstance(p, torch.cuda.FloatTensor)
                    test_case.assertEqual(p.get_device(), 1)
        else:
            # check that float()/double() casters work correctly

            # to float
            if input.device == torch.device("cpu"):
                if not isinstance(input, torch.LongTensor):
                    input = input.float()
                module.float()
                module(input)
                for p in module.parameters():
                    test_case.assertIsInstance(p, torch.FloatTensor)

                # and back to double
                if not isinstance(input, torch.LongTensor):
                    input = input.double()
                module.double()
                module(input)
                for p in module.parameters():
                    test_case.assertIsInstance(p, torch.DoubleTensor)
            # else: # for xpu
            # print()
            # if not isinstance(input, torch.xpu.LongTensor):
            # input = input.float()
            # module.float()
            # module(input)
            # for p in module.parameters():
            # test_case.assertIsInstance(p, torch.xpu.FloatTensor)

            # # and back to double
            # if not isinstance(input, torch.xpu.LongTensor):
            # input = input.double()
            # module.double()
            # module(input)
            # for p in module.parameters():
            # test_case.assertIsInstance(p, torch.xpu.DoubleTensor)

            if TEST_CUDA and self.should_test_cuda:
                # check that cuda() moves module parameters to correct GPU device,
                # and that float() casts parameters correctly

                # to GPU0
                input = input.float().cuda()
                module.float().cuda()
                module(input)
                for p in module.parameters():
                    test_case.assertIsInstance(p, torch.cuda.FloatTensor)
                    test_case.assertEqual(p.get_device(), 0)

                # to CPU
                input = input.cpu()
                module.cpu()
                module(input)
                for p in module.parameters():
                    test_case.assertIsInstance(p, torch.FloatTensor)

                # back to GPU0
                input = input.cuda()
                module.cuda()
                module(input)
                for p in module.parameters():
                    test_case.assertIsInstance(p, torch.cuda.FloatTensor)
                    test_case.assertEqual(p.get_device(), 0)

                # test that forwards of module runs correctly without cuDNN
                if self.cudnn:
                    with torch.backends.cudnn.flags(enabled=False):
                        module(input)
                        for p in module.parameters():
                            test_case.assertIsInstance(p, torch.cuda.FloatTensor)
                            test_case.assertEqual(p.get_device(), 0)

                if torch.cuda.device_count() >= 2:
                    # test cross-GPU transfer works
                    # to GPU1
                    input = input.cuda(1)
                    module.cuda(1)
                    with torch.cuda.device(1):
                        module(input)
                    for p in module.parameters():
                        test_case.assertIsInstance(p, torch.cuda.FloatTensor)
                        test_case.assertEqual(p.get_device(), 1)

                if not self.skip_double:
                    # test double()
                    input = input.double().cuda()
                    module.double().cuda()
                    module(input)
                    for p in module.parameters():
                        test_case.assertIsInstance(p, torch.cuda.DoubleTensor)
                        test_case.assertEqual(p.get_device(), 0)

                # test half()
                input = input.half().cuda()
                module.half().cuda()
                module(input)
                for p in module.parameters():
                    test_case.assertIsInstance(p, torch.cuda.HalfTensor)
                    test_case.assertEqual(p.get_device(), 0)

    def _get_target(self):
        """Return the (optional) target argument without unpacking."""
        return self._get_arg("target", False)

    @property
    def constructor_args(self):
        return self._get_arg("constructor_args", False)
5314
5315
class NewCriterionTest(InputVariableMixin, CriterionTest):
    """Criterion test with gradcheck/gradgradcheck and dtype-aware CUDA tests.

    Extends CriterionTest with second-order gradient checks and a
    ``test_cuda`` variant that can run the GPU side in a reduced dtype
    (half/bfloat16) while comparing against a full-precision CPU run.
    """

    # TODO: check that criterions don't ignore grad_output

    def __init__(self, *args, **kwargs):
        super(NewCriterionTest, self).__init__(*args, **kwargs)
        self.check_gradgrad = kwargs.get("check_gradgrad", True)
        self.check_half = kwargs.get("check_half", True)
        self.check_bfloat16 = kwargs.get("check_bfloat16", False)
        # Whether the target should also be cast to the test dtype
        # (NLLLoss-style integer targets must not be).
        self.convert_target = kwargs.get("convert_target", True)

    def _do_extra_tests(self, test_case, module, input, target):
        """Run gradcheck and gradgradcheck on the criterion."""
        if not self.check_gradgrad:
            return

        test_case.assertFalse(target.requires_grad)

        params = tuple(x for x in module.parameters())
        if not isinstance(input, tuple):
            inputs = (input,) + params

            def apply_fn(input, *params):
                return module(input, target)

        else:
            inputs = input + params

            def apply_fn(input1, input2, *params):
                return module(input1, input2, target)

        # TODO: we don't pass `target` as part of inputs because we don't
        # currently compute the gradient w.r.t. target for loss functions.
        gradcheck(apply_fn, inputs)
        gradgradcheck(apply_fn, inputs)

    def test_cuda(self, test_case, dtype=None, extra_args=None):
        """Compare CPU vs CUDA criterion results, optionally at *dtype*.

        For half/bfloat16 the CPU side is rebuilt at default precision
        (those dtypes lack CPU op coverage) and a looser tolerance is used.
        """

        def convert_dtype(obj, dtype, requires_grad=False):
            # Tensors: detach, cast, and re-mark for grad; tuples are
            # converted element-wise; anything else passes through.
            # (Fix: removed an unreachable duplicate
            # `elif isinstance(obj, torch.Tensor)` branch that was
            # shadowed by this first check.)
            if isinstance(obj, torch.Tensor):
                return obj.detach().to(dtype=dtype).requires_grad_(requires_grad)
            elif isinstance(obj, tuple):
                return tuple(convert_dtype(o, dtype, requires_grad) for o in obj)
            else:
                return obj

        if not TEST_CUDA or not self.should_test_cuda:
            raise unittest.SkipTest("Excluded from CUDA tests")
        try:
            cpu_input = self._get_input()
            cpu_target = self._get_target()
            cpu_module = self.constructor(*self.constructor_args)
            gpu_module = self.constructor(*self.constructor_args)

            # Convert input, target and module parameters to dtype
            if dtype is not None:
                cpu_input = convert_dtype(cpu_input, dtype, True)
                # NLLLoss requires target to be LongTensor
                if not isinstance(cpu_target, torch.LongTensor) and self.convert_target:
                    cpu_target = convert_dtype(cpu_target, dtype)
                cpu_module.type(dtype)
                gpu_module.type(dtype)

            # GPU setup
            gpu_input = to_gpu(cpu_input)
            gpu_target = to_gpu(cpu_target)
            gpu_module.cuda()

            # torch.HalfTensor doesn't support most operations, converting back to default
            if dtype in {torch.half, torch.bfloat16}:
                cpu_input = self._get_input()
                cpu_target = self._get_target()
                # Loss modules with weights require consistent input/module weight types
                cpu_module = self.constructor(*self.constructor_args)

            cpu_output = test_case._forward_criterion(
                cpu_module, cpu_input, cpu_target, extra_args=extra_args
            )
            gpu_output = test_case._forward_criterion(
                gpu_module, gpu_input, gpu_target, extra_args=extra_args
            )
            # dtype can be None, so set precision in this way instead of a precision map
            test_case.assertEqual(
                cpu_output,
                gpu_output,
                1e-1 if dtype in {torch.half, torch.bfloat16} else 4e-4,
            )

            cpu_gradInput = test_case._backward_criterion(
                cpu_module, cpu_input, cpu_target, extra_args=extra_args
            )
            gpu_gradInput = test_case._backward_criterion(
                gpu_module, gpu_input, gpu_target, extra_args=extra_args
            )
            test_case.assertEqual(
                cpu_gradInput,
                gpu_gradInput,
                1e-1 if dtype in {torch.half, torch.bfloat16} else 4e-4,
            )
        except NotImplementedError:
            # Criterion has no CUDA implementation; nothing to compare.
            pass

    def _get_target(self):
        """Return the (optional) target argument without unpacking."""
        return self._get_arg("target", False)

    @property
    def constructor_args(self):
        return self._get_arg("constructor_args", False)

    @property
    def extra_args(self):
        return self._get_arg("extra_args", False)
5427