## @package hypothesis_test_util
# Module caffe2.python.hypothesis_test_util
"""
The Hypothesis library uses *property-based testing* to check
invariants about the code under test under a variety of random inputs.

The key idea here is to express properties of the code under test
(e.g. that it passes a gradient check, that it implements a reference
function, etc.), and then generate random instances and verify that they
satisfy these properties.

The main functions of interest are exposed on `HypothesisTestCase`.
You can usually just add a short function to this class to generate an
arbitrary number of test cases for your operator.

The key functions are:

- `assertDeviceChecks(devices, op, inputs, outputs)`. This asserts that the
  operator computes the same outputs, regardless of which device it is executed
  on.
- `assertGradientChecks(device, op, inputs, output_,
  outputs_with_grads)`. This implements a standard numerical gradient checker
  for the operator in question.
- `assertReferenceChecks(device, op, inputs, reference)`. This runs the
  reference function (effectively calling `reference(*inputs)`) and compares
  that to the operator's output.

`hypothesis_test_util.py` exposes some useful pre-built samplers.

- `hu.gcs` - a gradient checker device (`gc`) and device checker devices (`dc`)

- `hu.gcs_cpu_only` - a CPU-only gradient checker device (`gc`) and
  device checker devices (`dc`). Use this when your operator is
  implemented only on the CPU.
"""


from caffe2.proto import caffe2_pb2
from caffe2.python import (
    workspace, device_checker, gradient_checker, test_util, core)
import contextlib
import copy
import functools
import hypothesis
import hypothesis.extra.numpy
import hypothesis.strategies as st
import logging
import numpy as np
import os
import struct


def is_sandcastle():
    return os.getenv('SANDCASTLE') == '1' or os.getenv('TW_JOB_USER') == 'sandcastle'


def is_travis():
    return 'TRAVIS' in os.environ


def to_float32(x):
    return struct.unpack("f", struct.pack("f", float(x)))[0]


# The `min_satisfying_examples` setting was deprecated in hypothesis 3.56.0
# and removed in hypothesis 4.x.
def settings(*args, **kwargs):
    if 'min_satisfying_examples' in kwargs and hypothesis.version.__version_info__ >= (3, 56, 0):
        kwargs.pop('min_satisfying_examples')

    if 'deadline' in kwargs and hypothesis.version.__version_info__ < (4, 44, 0):
        kwargs.pop('deadline')

    if 'timeout' in kwargs and hypothesis.version.__version_info__ >= (4, 44, 0):
        if 'deadline' not in kwargs:
            # `timeout` was specified in seconds; `deadline` is in milliseconds.
            kwargs['deadline'] = kwargs['timeout'] * 1e3
        kwargs.pop('timeout')

    return hypothesis.settings(*args, **kwargs)
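
# Sketch of the intended call pattern: the wrapper silently drops or converts
# kwargs that the installed hypothesis version does not understand.
#
#     @settings(max_examples=50, timeout=120)
#     @given(X=st.integers())
#     def test_op(self, X):
#         # on hypothesis >= 4.44 the timeout above becomes deadline=120000 ms
#         ...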


# This wrapper around `st.floats` sets the width parameter to 32 when the
# installed hypothesis version supports it (>= 3.67.0).
def floats(*args, **kwargs):
    width_supported = hypothesis.version.__version_info__ >= (3, 67, 0)
    if 'width' in kwargs and not width_supported:
        kwargs.pop('width')

    if 'width' not in kwargs and width_supported:
        kwargs['width'] = 32
        if kwargs.get('min_value', None) is not None:
            kwargs['min_value'] = to_float32(kwargs['min_value'])
        if kwargs.get('max_value', None) is not None:
            kwargs['max_value'] = to_float32(kwargs['max_value'])

    return st.floats(*args, **kwargs)


hypothesis.settings.register_profile(
    "sandcastle",
    settings(
        derandomize=True,
        suppress_health_check=[hypothesis.HealthCheck.too_slow],
        database=None,
        max_examples=50,
        min_satisfying_examples=1,
        verbosity=hypothesis.Verbosity.verbose,
        deadline=10000))
hypothesis.settings.register_profile(
    "dev",
    settings(
        suppress_health_check=[hypothesis.HealthCheck.too_slow],
        database=None,
        max_examples=10,
        min_satisfying_examples=1,
        verbosity=hypothesis.Verbosity.verbose,
        deadline=10000))
hypothesis.settings.register_profile(
    "debug",
    settings(
        suppress_health_check=[hypothesis.HealthCheck.too_slow],
        database=None,
        max_examples=1000,
        min_satisfying_examples=1,
        verbosity=hypothesis.Verbosity.verbose,
        deadline=50000))

hypothesis.settings.load_profile(
    'sandcastle' if is_sandcastle() else os.getenv('CAFFE2_HYPOTHESIS_PROFILE',
                                                   'dev')
)
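
# The active profile can be selected per run via the environment, e.g.
# (the test file name is a placeholder):
#
#     CAFFE2_HYPOTHESIS_PROFILE=debug python my_operator_test.py
#
# It falls back to "dev" when the variable is unset, and "sandcastle" is
# forced on Sandcastle CI.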


def dims(min_value=1, max_value=5):
    return st.integers(min_value=min_value, max_value=max_value)


def elements_of_type(dtype=np.float32, filter_=None):
    elems = None
    if dtype is np.float16:
        elems = floats(min_value=-1.0, max_value=1.0, width=16)
    elif dtype is np.float32:
        elems = floats(min_value=-1.0, max_value=1.0, width=32)
    elif dtype is np.float64:
        elems = floats(min_value=-1.0, max_value=1.0, width=64)
    elif dtype is np.int32:
        elems = st.integers(min_value=0, max_value=2 ** 31 - 1)
    elif dtype is np.int64:
        elems = st.integers(min_value=0, max_value=2 ** 63 - 1)
    elif dtype is bool:
        elems = st.booleans()
    else:
        raise ValueError("Unexpected dtype without elements provided")
    return elems if filter_ is None else elems.filter(filter_)


def arrays(dims, dtype=np.float32, elements=None, unique=False):
    if elements is None:
        elements = elements_of_type(dtype)
    return hypothesis.extra.numpy.arrays(
        dtype,
        dims,
        elements=elements,
        unique=unique,
    )


def tensor(min_dim=1,
           max_dim=4,
           dtype=np.float32,
           elements=None,
           unique=False,
           **kwargs):
    dims_ = st.lists(dims(**kwargs), min_size=min_dim, max_size=max_dim)
    return dims_.flatmap(
        lambda dims: arrays(dims, dtype, elements, unique=unique))
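
# Example draw (a sketch; the shape and values vary from example to example):
#
#     @given(X=tensor(min_dim=2, max_dim=3))
#     def test_something(self, X):
#         assert 2 <= X.ndim <= 3  # each dim is in [1, 5] by default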


def tensor1d(min_len=1, max_len=64, dtype=np.float32, elements=None):
    return tensor(1, 1, dtype, elements, min_value=min_len, max_value=max_len)


def segment_ids(size, is_sorted):
    if size == 0:
        return st.just(np.empty(shape=[0], dtype=np.int32))
    if is_sorted:
        # Draw booleans and take a running sum, so ids start at 0 and
        # increase by at most 1 between neighbours.
        return arrays(
            [size],
            dtype=np.int32,
            elements=st.booleans()).map(
                lambda x: np.cumsum(x, dtype=np.int32) - x[0])
    else:
        return arrays(
            [size],
            dtype=np.int32,
            elements=st.integers(min_value=0, max_value=2 * size))
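
# Worked example for the sorted case: drawing x = [1, 0, 1, 1] gives
# cumsum [1, 1, 2, 3]; subtracting x[0] yields segment ids [0, 0, 1, 2].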


def lengths(size, min_segments=None, max_segments=None, **kwargs):
    # First generate the number of borders between segments.
    # Then create the border values and add 0 and size.
    # Sorting and taking the diff converts them into segment lengths,
    # which may be 0.
    if min_segments is None:
        min_segments = 0
    if max_segments is None:
        max_segments = size
    assert min_segments >= 0
    assert min_segments <= max_segments
    if size == 0 and max_segments == 0:
        return st.just(np.empty(shape=[0], dtype=np.int32))
    assert max_segments > 0, "size is not 0, need at least one segment"
    return st.integers(
        min_value=max(min_segments - 1, 0), max_value=max_segments - 1
    ).flatmap(
        lambda num_borders:
        hypothesis.extra.numpy.arrays(
            np.int32, num_borders, elements=st.integers(
                min_value=0, max_value=size
            )
        )
    ).map(
        lambda x: np.append(x, np.array([0, size], dtype=np.int32))
    ).map(sorted).map(np.diff)
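
# Worked example for size=5: drawing borders [3, 1] gives, after appending
# 0 and 5 and sorting, [0, 1, 3, 5]; np.diff then yields lengths [1, 2, 2],
# which sum to size. Duplicate borders produce zero-length segments.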


def segmented_tensor(
    min_dim=1,
    max_dim=4,
    dtype=np.float32,
    is_sorted=True,
    elements=None,
    segment_generator=segment_ids,
    allow_empty=False,
    **kwargs
):
    gen_empty = st.booleans() if allow_empty else st.just(False)
    data_dims_ = st.lists(dims(**kwargs), min_size=min_dim, max_size=max_dim)
    data_dims_ = st.tuples(
        gen_empty, data_dims_
    ).map(lambda pair: ([0] if pair[0] else []) + pair[1])
    return data_dims_.flatmap(lambda data_dims: st.tuples(
        arrays(data_dims, dtype, elements),
        segment_generator(data_dims[0], is_sorted=is_sorted),
    ))


def lengths_tensor(min_segments=None, max_segments=None, *args, **kwargs):
    gen = functools.partial(
        lengths, min_segments=min_segments, max_segments=max_segments)
    return segmented_tensor(*args, segment_generator=gen, **kwargs)


def sparse_segmented_tensor(min_dim=1, max_dim=4, dtype=np.float32,
                            is_sorted=True, elements=None, allow_empty=False,
                            segment_generator=segment_ids, itype=np.int64,
                            **kwargs):
    gen_empty = st.booleans() if allow_empty else st.just(False)
    data_dims_ = st.lists(dims(**kwargs), min_size=min_dim, max_size=max_dim)
    all_dims_ = st.tuples(gen_empty, data_dims_).flatmap(
        lambda pair: st.tuples(
            st.just(pair[1]),
            (st.integers(min_value=1, max_value=pair[1][0]) if not pair[0]
             else st.just(0)),
        ))
    return all_dims_.flatmap(lambda dims: st.tuples(
        arrays(dims[0], dtype, elements),
        arrays(dims[1], dtype=itype, elements=st.integers(
            min_value=0, max_value=dims[0][0] - 1)),
        segment_generator(dims[1], is_sorted=is_sorted),
    ))
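
# The strategy above yields a (data, indices, segment_ids) triple: `indices`
# selects which rows of `data` are gathered, and the segment generator runs
# over the gathered rows rather than over `data` itself.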


def sparse_lengths_tensor(**kwargs):
    return sparse_segmented_tensor(segment_generator=lengths, **kwargs)


def tensors(n, min_dim=1, max_dim=4, dtype=np.float32, elements=None, **kwargs):
    dims_ = st.lists(dims(**kwargs), min_size=min_dim, max_size=max_dim)
    return dims_.flatmap(
        lambda dims: st.lists(
            arrays(dims, dtype, elements),
            min_size=n,
            max_size=n))


def tensors1d(n, min_len=1, max_len=64, dtype=np.float32, elements=None):
    return tensors(
        n, 1, 1, dtype, elements, min_value=min_len, max_value=max_len
    )


cpu_do = caffe2_pb2.DeviceOption()
cuda_do = caffe2_pb2.DeviceOption(device_type=caffe2_pb2.CUDA)
hip_do = caffe2_pb2.DeviceOption(device_type=caffe2_pb2.HIP)
gpu_do = caffe2_pb2.DeviceOption(device_type=workspace.GpuDeviceType)  # CUDA or ROCm
_cuda_do_list = ([cuda_do] if workspace.has_cuda_support else [])
_hip_do_list = ([hip_do] if workspace.has_hip_support else [])
_gpu_do_list = ([gpu_do] if workspace.has_gpu_support else [])
# (bddppq) Do not rely on this no_hip option! It is only used to temporarily
# skip some flaky tests on ROCm until support there matures.
_device_options_no_hip = [cpu_do] + _cuda_do_list
device_options = _device_options_no_hip + _hip_do_list

# Include a device option for each GPU
expanded_device_options = [cpu_do] + [
    caffe2_pb2.DeviceOption(device_type=workspace.GpuDeviceType, device_id=i)
    for i in range(workspace.NumGpuDevices())]


def device_checker_device_options():
    return st.just(device_options)


def gradient_checker_device_option():
    return st.sampled_from(device_options)


gcs = dict(
    gc=gradient_checker_device_option(),
    dc=device_checker_device_options()
)

gcs_cpu_only = dict(gc=st.sampled_from([cpu_do]), dc=st.just([cpu_do]))
gcs_cuda_only = dict(gc=st.sampled_from(_cuda_do_list), dc=st.just(_cuda_do_list))
gcs_gpu_only = dict(gc=st.sampled_from(_gpu_do_list), dc=st.just(_gpu_do_list))  # CUDA or ROCm
gcs_no_hip = dict(gc=st.sampled_from(_device_options_no_hip), dc=st.just(_device_options_no_hip))
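
# Typical use: splat one of these dicts into @given so the test receives a
# gradient-check device `gc` and a device-check list `dc`, e.g.:
#
#     @given(X=hu.tensor(), **hu.gcs_cpu_only)
#     def test_cpu_only_op(self, X, gc, dc):
#         ...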


@contextlib.contextmanager
def temp_workspace(name=b"temp_ws"):
    old_ws_name = workspace.CurrentWorkspace()
    workspace.SwitchWorkspace(name, True)
    yield
    workspace.ResetWorkspace()
    workspace.SwitchWorkspace(old_ws_name)
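
# Sketch: blobs fed inside the block live in the temporary workspace and are
# dropped when the block exits normally:
#
#     with temp_workspace():
#         workspace.FeedBlob("X", np.ones(3, dtype=np.float32))
#         ...  # run ops, fetch outputs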


def runOpBenchmark(
    device_option,
    op,
    inputs,
    input_device_options=None,
    iterations=10,
):
    op = copy.deepcopy(op)
    op.device_option.CopyFrom(device_option)
    net = caffe2_pb2.NetDef()
    net.op.extend([op])
    net.name = op.name if op.name else "test"

    with temp_workspace():
        _input_device_options = input_device_options or \
            core.InferOpBlobDevicesAsDict(op)[0]
        for (n, b) in zip(op.input, inputs):
            workspace.FeedBlob(
                n,
                b,
                device_option=_input_device_options.get(n, device_option)
            )
        workspace.CreateNet(net)
        ret = workspace.BenchmarkNet(net.name, 1, iterations, True)
    return ret
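
# Sketch: time a "Sum" op on the CPU for 100 iterations; x1 and x2 are
# hypothetical numpy inputs matching the op's input names. The return value
# is whatever workspace.BenchmarkNet reports:
#
#     op = core.CreateOperator("Sum", ["X1", "X2"], ["Y"])
#     ret = runOpBenchmark(cpu_do, op, [x1, x2], iterations=100)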


def runOpOnInput(
    device_option,
    op,
    inputs,
    input_device_options=None,
):
    op = copy.deepcopy(op)
    op.device_option.CopyFrom(device_option)

    with temp_workspace():
        if len(op.input) > len(inputs):
            raise ValueError(
                'must supply an input for each input on the op: %s vs %s' %
                (op.input, inputs))
        _input_device_options = input_device_options or \
            core.InferOpBlobDevicesAsDict(op)[0]
        for (n, b) in zip(op.input, inputs):
            workspace.FeedBlob(
                n,
                b,
                device_option=_input_device_options.get(n, device_option)
            )
        workspace.RunOperatorOnce(op)
        outputs_to_check = list(range(len(op.output)))
        outs = []
        for output_index in outputs_to_check:
            output_blob_name = op.output[output_index]
            output = workspace.FetchBlob(output_blob_name)
            outs.append(output)
        return outs
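
# Sketch: run an op once on concrete inputs and fetch every output blob,
# returned in the order the outputs are declared on the op:
#
#     outs = runOpOnInput(cpu_do, op, [x1, x2])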


class HypothesisTestCase(test_util.TestCase):
    """
    A unittest.TestCase subclass with some helper functions for
    utilizing the `hypothesis` (hypothesis.readthedocs.io) library.
    """

    def assertDeviceChecks(
        self,
        device_options,
        op,
        inputs,
        outputs_to_check,
        input_device_options=None,
        threshold=0.01
    ):
        """
        Asserts that the operator computes the same outputs, regardless of
        which device it is executed on.

        Useful for checking the consistency of GPU and CPU
        implementations of operators.

        Usage example:

            @given(inputs=hu.tensors(n=2), in_place=st.booleans(), **hu.gcs)
            def test_sum(self, inputs, in_place, gc, dc):
                op = core.CreateOperator("Sum", ["X1", "X2"],
                                                ["Y" if not in_place else "X1"])
                X1, X2 = inputs
                self.assertDeviceChecks(dc, op, [X1, X2], [0])
        """
        dc = device_checker.DeviceChecker(
            threshold,
            device_options=device_options
        )
        self.assertTrue(
            dc.CheckSimple(op, inputs, outputs_to_check, input_device_options)
        )

    def assertGradientChecks(
        self,
        device_option,
        op,
        inputs,
        outputs_to_check,
        outputs_with_grads,
        grad_ops=None,
        threshold=0.005,
        stepsize=0.05,
        input_device_options=None,
        ensure_outputs_are_inferred=False,
    ):
        """
        Implements a standard numerical gradient checker for the operator
        in question.

        Useful for checking the consistency of the forward and
        backward implementations of operators.

        Usage example:

            @given(inputs=hu.tensors(n=2), in_place=st.booleans(), **hu.gcs)
            def test_sum(self, inputs, in_place, gc, dc):
                op = core.CreateOperator("Sum", ["X1", "X2"],
                                                ["Y" if not in_place else "X1"])
                X1, X2 = inputs
                self.assertGradientChecks(gc, op, [X1, X2], 0, [0])
        """
        gc = gradient_checker.GradientChecker(
            stepsize=stepsize,
            threshold=threshold,
            device_option=device_option,
            workspace_name=str(device_option),
            input_device_options=input_device_options,
        )
        # NOTE: despite its name, `outputs_to_check` is forwarded to
        # GradientChecker.CheckSimple as the index of the *input* whose
        # gradient is checked (hence op.input is indexed with it below).
        res, grad, grad_estimated = gc.CheckSimple(
            op, inputs, outputs_to_check, outputs_with_grads,
            grad_ops=grad_ops,
            input_device_options=input_device_options,
            ensure_outputs_are_inferred=ensure_outputs_are_inferred,
        )
        self.assertEqual(grad.shape, grad_estimated.shape)
        self.assertTrue(
            res,
            "Gradient check failed for input " + str(op.input[outputs_to_check])
        )

    def _assertGradReferenceChecks(
        self,
        op,
        inputs,
        ref_outputs,
        output_to_grad,
        grad_reference,
        threshold=1e-4,
    ):
        grad_blob_name = output_to_grad + '_grad'
        grad_ops, grad_map = core.GradientRegistry.GetBackwardPass(
            [op], {output_to_grad: grad_blob_name})
        output_grad = workspace.FetchBlob(output_to_grad)
        grad_ref_outputs = grad_reference(output_grad, ref_outputs, inputs)
        workspace.FeedBlob(grad_blob_name, workspace.FetchBlob(output_to_grad))
        workspace.RunOperatorsOnce(grad_ops)

        self.assertEqual(len(grad_ref_outputs), len(inputs))
        for (n, ref) in zip(op.input, grad_ref_outputs):
            grad_names = grad_map.get(n)
            if not grad_names:
                # no grad for this input
                self.assertIsNone(ref)
            else:
                if isinstance(grad_names, core.BlobReference):
                    # dense gradient
                    ref_vals = ref
                    ref_indices = None
                    val_name = grad_names
                else:
                    # sparse gradient
                    ref_vals, ref_indices = ref
                    val_name = grad_names.values
                vals = workspace.FetchBlob(str(val_name))
                np.testing.assert_allclose(
                    vals,
                    ref_vals,
                    atol=threshold,
                    rtol=threshold,
                    err_msg='Gradient {0} (x) does not match the reference (y)'
                    .format(val_name),
                )
                if ref_indices is not None:
                    indices = workspace.FetchBlob(str(grad_names.indices))
                    np.testing.assert_allclose(indices, ref_indices,
                                               atol=1e-4, rtol=1e-4)

    def _assertInferTensorChecks(self, name, shapes, types, output,
                                 ensure_output_is_inferred=False):
        self.assertTrue(
            not ensure_output_is_inferred or (name in shapes),
            'Shape for {0} was not inferred'.format(name))

        if name not in shapes:
            # No inferred shape or type available
            return
        output = workspace.FetchBlob(name)
        if type(output) is np.ndarray:
            if output.dtype == np.dtype('float64'):
                correct_type = caffe2_pb2.TensorProto.DOUBLE
            elif output.dtype == np.dtype('float32'):
                correct_type = caffe2_pb2.TensorProto.FLOAT
            elif output.dtype == np.dtype('int32'):
                correct_type = caffe2_pb2.TensorProto.INT32
            elif output.dtype == np.dtype('int64'):
                correct_type = caffe2_pb2.TensorProto.INT64
            else:
                correct_type = "unknown {}".format(output.dtype)
        else:
            correct_type = str(type(output))
        try:
            np.testing.assert_array_equal(
                np.array(shapes[name]).astype(np.int32),
                np.array(output.shape).astype(np.int32),
                err_msg='Shape {} mismatch: {} vs. {}'.format(
                    name,
                    shapes[name],
                    output.shape))
            # BUG: Workspace blob type not being set correctly T16121392
            if correct_type != caffe2_pb2.TensorProto.INT32:
                return
            np.testing.assert_equal(
                types[name],
                correct_type,
                err_msg='Type {} mismatch: {} vs. {}'.format(
                    name, types[name], correct_type,
                )
            )
        except AssertionError as e:
            # Temporarily catch these assertion errors when validating
            # inferred shape and type info
            logging.warning(str(e))
            if os.getenv('CAFFE2_ASSERT_SHAPEINFERENCE') == '1' or ensure_output_is_inferred:
                raise e

    def assertReferenceChecks(
        self,
        device_option,
        op,
        inputs,
        reference,
        input_device_options=None,
        threshold=1e-4,
        output_to_grad=None,
        grad_reference=None,
        atol=None,
        outputs_to_check=None,
        ensure_outputs_are_inferred=False,
    ):
        """
        This runs the reference Python function implementation
        (effectively calling `reference(*inputs)`) and compares that
        to the operator's output, with an absolute/relative tolerance
        given by the `threshold` parameter.

        Useful for checking that the implementation matches the Python
        (typically NumPy) implementation of the same functionality.

        Usage example:

            @given(X=hu.tensor(), inplace=st.booleans(), **hu.gcs)
            def test_softsign(self, X, inplace, gc, dc):
                op = core.CreateOperator(
                    "Softsign", ["X"], ["X" if inplace else "Y"])

                def softsign(X):
                    return (X / (1 + np.abs(X)),)

                self.assertReferenceChecks(gc, op, [X], softsign)
        """
        op = copy.deepcopy(op)
        op.device_option.CopyFrom(device_option)

        with temp_workspace():
            if len(op.input) > len(inputs):
                raise ValueError(
                    'must supply an input for each input on the op: %s vs %s' %
                    (op.input, inputs))
            _input_device_options = input_device_options or \
                core.InferOpBlobDevicesAsDict(op)[0]
            for (n, b) in zip(op.input, inputs):
                workspace.FeedBlob(
                    n,
                    b,
                    device_option=_input_device_options.get(n, device_option)
                )
            net = core.Net("opnet")
            net.Proto().op.extend([op])
            test_shape_inference = False
            try:
                (shapes, types) = workspace.InferShapesAndTypes([net])
                test_shape_inference = True
            except RuntimeError as e:
                # Temporarily catch runtime errors when inferring shape
                # and type info
                logging.warning(str(e))
                if os.getenv('CAFFE2_ASSERT_SHAPEINFERENCE') == '1' or ensure_outputs_are_inferred:
                    raise e
            workspace.RunNetOnce(net)
            reference_outputs = reference(*inputs)
            if not (isinstance(reference_outputs, tuple) or
                    isinstance(reference_outputs, list)):
                raise RuntimeError(
                    "You are providing a wrong reference implementation. A "
                    "proper one should return a tuple/list of numpy arrays.")
            if not outputs_to_check:
                self.assertEqual(len(reference_outputs), len(op.output))
                outputs_to_check = list(range(len(op.output)))
            outs = []
            for (output_index, ref) in zip(outputs_to_check, reference_outputs):
                output_blob_name = op.output[output_index]
                output = workspace.FetchBlob(output_blob_name)
                if output.dtype.kind in ('S', 'O'):
                    np.testing.assert_array_equal(output, ref)
                else:
                    if atol is None:
                        atol = threshold
                    np.testing.assert_allclose(
                        output, ref, atol=atol, rtol=threshold,
                        err_msg=(
                            'Output {0} does not match the reference'.format(
                                output_blob_name,
                            )),
                    )
                if test_shape_inference:
                    self._assertInferTensorChecks(
                        output_blob_name, shapes, types, output,
                        ensure_output_is_inferred=ensure_outputs_are_inferred)
                outs.append(output)
            if grad_reference is not None:
                assert output_to_grad is not None, \
                    "If grad_reference is set, " \
                    "output_to_grad has to be set as well"

                with core.DeviceScope(device_option):
                    self._assertGradReferenceChecks(
                        op, inputs, reference_outputs,
                        output_to_grad, grad_reference,
                        threshold=threshold)

            return outs

    def assertValidationChecks(
            self,
            device_option,
            op,
            inputs,
            validator,
            input_device_options=None,
            as_kwargs=True,
            init_net=None,
    ):
        if as_kwargs:
            assert len(set(list(op.input) + list(op.output))) == \
                len(op.input) + len(op.output), \
                "in-place ops are not supported in as_kwargs mode"
        op = copy.deepcopy(op)
        op.device_option.CopyFrom(device_option)

        with temp_workspace():
            _input_device_options = input_device_options or \
                core.InferOpBlobDevicesAsDict(op)[0]
            for (n, b) in zip(op.input, inputs):
                workspace.FeedBlob(
                    n,
                    b,
                    device_option=_input_device_options.get(n, device_option)
                )
            if init_net:
                workspace.RunNetOnce(init_net)
            workspace.RunOperatorOnce(op)
            outputs = [workspace.FetchBlob(n) for n in op.output]
            if as_kwargs:
                validator(**dict(zip(
                    list(op.input) + list(op.output), inputs + outputs)))
            else:
                validator(inputs=inputs, outputs=outputs)

    def assertRunOpRaises(
        self,
        device_option,
        op,
        inputs,
        input_device_options=None,
        exception=(Exception,),
        regexp=None,
    ):
        op = copy.deepcopy(op)
        op.device_option.CopyFrom(device_option)

        with temp_workspace():
            _input_device_options = input_device_options or \
                core.InferOpBlobDevicesAsDict(op)[0]
            for (n, b) in zip(op.input, inputs):
                workspace.FeedBlob(
                    n,
                    b,
                    device_option=_input_device_options.get(n, device_option)
                )
            if regexp is None:
                self.assertRaises(exception, workspace.RunOperatorOnce, op)
            else:
                self.assertRaisesRegex(
                    exception, regexp, workspace.RunOperatorOnce, op)