# pytorch — caffe2 convolution operator test suite (1007 lines, 31.7 KB)
3import collections
4import functools
5import unittest
6
7import caffe2.python._import_c_extension as C
8import caffe2.python.hip_test_util as hiputl
9import caffe2.python.hypothesis_test_util as hu
10import caffe2.python.serialized_test.serialized_test_util as serial
11import hypothesis.strategies as st
12import numpy as np
13from caffe2.proto import caffe2_pb2
14from caffe2.python import brew, core, utils, workspace
15from caffe2.python.model_helper import ModelHelper
16from hypothesis import assume, given, settings
17
18
def _cudnn_supports(dilation=False, nhwc=False, backward=False):
    """Return True if cuDNN supports this configuration."""
    version = workspace.GetCuDNNVersion()
    if backward:
        # Backward ops never support NHWC; everything else is fine.
        return not nhwc
    # Forward mode: dilation requires cuDNN v6+, and dilation+NHWC is
    # unsupported in any version.
    if dilation and (version < 6000 or nhwc):
        return False
    return True
35
36
def _cudnn_convolution_algo_count(direction):
    """Return a hypothesis strategy over valid cuDNN algo indices for
    *direction* ("fwd", "dgrad" or "wgrad"), or sampled_from([-1]) when the
    C extension does not expose the algo counts (e.g. CPU-only builds)."""
    attr_by_direction = {
        "fwd": "cudnn_convolution_fwd_algo_count",
        "dgrad": "cudnn_convolution_bwd_data_algo_count",
        "wgrad": "cudnn_convolution_bwd_filter_algo_count",
    }
    try:
        # Unknown direction -> AssertionError, caught below like a missing
        # extension attribute (mirrors the original assert False fallback).
        assert direction in attr_by_direction
        count = getattr(C, attr_by_direction[direction])
        return st.integers(0, count - 1)
    except Exception:
        return st.sampled_from([-1])
49
50
class TestConvolution(serial.SerializedTestCase):
    # CUDNN does NOT support different padding values and we skip it
    @given(
        op_type=st.sampled_from(["Conv", "Conv2D"]),
        stride_h=st.integers(1, 3),
        stride_w=st.integers(1, 3),
        pad_t=st.integers(0, 3),
        pad_l=st.integers(0, 3),
        pad_b=st.integers(0, 3),
        pad_r=st.integers(0, 3),
        kernel=st.integers(3, 5),
        size=st.integers(1, 8),
        input_channels=st.integers(1, 3),
        output_channels=st.integers(1, 3),
        batch_size=st.integers(0, 3),
        group=st.integers(1, 2),
        order=st.sampled_from(["NCHW", "NHWC"]),
        engine=st.sampled_from(["", "EIGEN"]),
        shared_buffer=st.booleans(),
        use_bias=st.booleans(),
        **hu.gcs
    )
    @settings(deadline=None, max_examples=50)
    def test_convolution_separate_stride_pad_gradients(
        self,
        op_type,
        stride_h,
        stride_w,
        pad_t,
        pad_l,
        pad_b,
        pad_r,
        kernel,
        size,
        input_channels,
        output_channels,
        batch_size,
        group,
        order,
        engine,
        shared_buffer,
        use_bias,
        gc,
        dc,
    ):
        """Gradient-check Conv with per-side pads and per-axis strides
        across all devices in `dc`."""
        # TODO: Group conv in NHWC not implemented for GPU yet.
        assume(group == 1 or order == "NCHW" or gc.device_type == caffe2_pb2.CPU)
        if group != 1 and order == "NHWC":
            # Restrict device checks to CPU for the unimplemented combo.
            dc = [d for d in dc if d.device_type == caffe2_pb2.CPU]
        # Group conv not implemented with EIGEN engine.
        assume(group == 1 or engine != "EIGEN")

        # Channel counts must be divisible by the group count.
        input_channels *= group
        output_channels *= group

        op = core.CreateOperator(
            op_type,
            ["X", "w", "b"] if use_bias else ["X", "w"],
            ["Y"],
            stride_h=stride_h,
            stride_w=stride_w,
            pad_t=pad_t,
            pad_l=pad_l,
            pad_b=pad_b,
            pad_r=pad_r,
            kernel=kernel,
            group=group,
            order=order,
            engine=engine,
            shared_buffer=int(shared_buffer),
        )
        # Inputs are generated in NHWC and transposed below when NCHW.
        X = (
            np.random.rand(batch_size, size, size, input_channels).astype(np.float32)
            - 0.5
        )
        w = (
            np.random.rand(
                output_channels, kernel, kernel, int(input_channels / group)
            ).astype(np.float32)
            - 0.5
        )
        b = np.random.rand(output_channels).astype(np.float32) - 0.5
        if order == "NCHW":
            X = utils.NHWC2NCHW(X)
            w = utils.NHWC2NCHW(w)

        inputs = [X, w, b] if use_bias else [X, w]

        # Error handling path: padded input smaller than the kernel must raise.
        if size + pad_r + pad_l < kernel or size + pad_t + pad_b < kernel:
            with self.assertRaises(RuntimeError):
                self.assertDeviceChecks(dc, op, inputs, [0])
            return

        self.assertDeviceChecks(dc, op, inputs, [0])
        for i in range(len(inputs)):
            self.assertGradientChecks(gc, op, inputs, i, [0])
148
149# CUDNN does NOT support different padding values and we skip it
150@given(
151op_type=st.sampled_from(["Conv", "Conv2D"]),
152stride_h=st.integers(1, 3),
153stride_w=st.integers(1, 3),
154pad_t=st.integers(0, 3),
155pad_l=st.integers(0, 3),
156pad_b=st.integers(0, 3),
157pad_r=st.integers(0, 3),
158kernel=st.integers(1, 5),
159size=st.integers(7, 10),
160input_channels=st.integers(1, 8),
161output_channels=st.integers(1, 8),
162batch_size=st.integers(0, 3),
163engine=st.sampled_from(["", "EIGEN"]),
164use_bias=st.booleans(),
165**hu.gcs
166)
167@settings(deadline=None)
168def test_convolution_separate_stride_pad_layout(
169self,
170op_type,
171stride_h,
172stride_w,
173pad_t,
174pad_l,
175pad_b,
176pad_r,
177kernel,
178size,
179input_channels,
180output_channels,
181batch_size,
182engine,
183use_bias,
184gc,
185dc,
186):
187X = (
188np.random.rand(batch_size, size, size, input_channels).astype(np.float32)
189- 0.5
190)
191w = (
192np.random.rand(output_channels, kernel, kernel, input_channels).astype(
193np.float32
194)
195- 0.5
196)
197b = np.random.rand(output_channels).astype(np.float32) - 0.5
198outputs = {}
199for order in ["NCHW", "NHWC"]:
200op = core.CreateOperator(
201op_type,
202["X", "w", "b"] if use_bias else ["X", "w"],
203["Y"],
204stride_h=stride_h,
205stride_w=stride_w,
206kernel=kernel,
207pad_t=pad_t,
208pad_l=pad_l,
209pad_b=pad_b,
210pad_r=pad_r,
211order=order,
212engine=engine,
213device_option=gc,
214)
215if order == "NCHW":
216X_f = utils.NHWC2NCHW(X)
217w_f = utils.NHWC2NCHW(w)
218else:
219X_f = X
220w_f = w
221self.ws.create_blob("X").feed(X_f, device_option=gc)
222self.ws.create_blob("w").feed(w_f, device_option=gc)
223self.ws.create_blob("b").feed(b, device_option=gc)
224self.ws.run(op)
225outputs[order] = self.ws.blobs["Y"].fetch()
226np.testing.assert_allclose(
227outputs["NCHW"], utils.NHWC2NCHW(outputs["NHWC"]), atol=1e-4, rtol=1e-4
228)
229
230@given(
231op_type=st.sampled_from(["Conv", "Conv2D"]),
232stride=st.integers(1, 3),
233pad=st.integers(0, 3),
234kernel=st.integers(1, 5),
235dilation=st.integers(1, 3),
236size=st.integers(7, 10),
237input_channels=st.integers(1, 8),
238output_channels=st.integers(1, 8),
239batch_size=st.integers(0, 3),
240group=st.integers(1, 2),
241order=st.sampled_from(["NCHW", "NHWC"]),
242engine=st.sampled_from(["", "CUDNN", "MKLDNN"]),
243use_bias=st.booleans(),
244force_algo_fwd=_cudnn_convolution_algo_count("fwd"),
245force_algo_dgrad=_cudnn_convolution_algo_count("dgrad"),
246force_algo_wgrad=_cudnn_convolution_algo_count("wgrad"),
247**hu.gcs
248)
249@settings(max_examples=20, deadline=None)
250def test_convolution_gradients(
251self,
252op_type,
253stride,
254pad,
255kernel,
256dilation,
257size,
258input_channels,
259output_channels,
260batch_size,
261group,
262order,
263engine,
264use_bias,
265force_algo_fwd,
266force_algo_dgrad,
267force_algo_wgrad,
268gc,
269dc,
270):
271# TODO: Group conv in NHWC not implemented for GPU yet.
272assume(
273group == 1
274or (order == "NCHW" or gc.device_type == caffe2_pb2.CPU)
275and engine != "MKLDNN"
276)
277if group != 1 and order == "NHWC":
278dc = [d for d in dc if d.device_type == caffe2_pb2.CPU]
279
280input_channels *= group
281output_channels *= group
282dkernel = dilation * (kernel - 1) + 1
283
284if engine == "CUDNN":
285if hiputl.run_in_hip(gc, dc):
286assume((order == "NCHW") and not (dilation > 1 and group > 1))
287else:
288assume(
289_cudnn_supports(
290dilation=(dilation > 1), nhwc=(order == "NHWC"), backward=True
291)
292)
293
294assume(engine != "MKLDNN" or use_bias is True)
295
296op = core.CreateOperator(
297op_type,
298["X", "w", "b"] if use_bias else ["X", "w"],
299["Y"],
300stride=stride,
301kernel=kernel,
302dilation=dilation,
303pad=pad,
304group=group,
305order=order,
306engine=engine,
307force_algo_fwd=force_algo_fwd,
308force_algo_dgrad=force_algo_dgrad,
309force_algo_wgrad=force_algo_wgrad,
310)
311X = (
312np.random.rand(batch_size, size, size, input_channels).astype(np.float32)
313- 0.5
314)
315w = (
316np.random.rand(
317output_channels, kernel, kernel, int(input_channels / group)
318).astype(np.float32)
319- 0.5
320)
321b = np.random.rand(output_channels).astype(np.float32) - 0.5
322if order == "NCHW":
323X = utils.NHWC2NCHW(X)
324w = utils.NHWC2NCHW(w)
325
326inputs = [X, w, b] if use_bias else [X, w]
327# Error handling path.
328if size + pad + pad < dkernel or size + pad + pad < dkernel:
329with self.assertRaises(RuntimeError):
330self.assertDeviceChecks(dc, op, inputs, [0])
331return
332
333try:
334self.assertDeviceChecks(dc, op, inputs, [0])
335except RuntimeError as e:
336es = str(e)
337# CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM should always have
338# implementation
339if (
340"status == CUDNN_STATUS_SUCCESS" not in es
341or "CUDNN_STATUS_NOT_SUPPORTED" not in es
342or force_algo_fwd == 0
343):
344raise e
345
346for i in range(len(inputs)):
347try:
348self.assertGradientChecks(gc, op, inputs, i, [0])
349except RuntimeError as e:
350es = str(e)
351if (
352"status == CUDNN_STATUS_SUCCESS" not in es
353or "CUDNN_STATUS_NOT_SUPPORTED" not in es
354):
355raise e
356
357def _nd_convolution(
358self,
359n,
360input_channels_per_group,
361output_channels_per_group,
362batch_size,
363stride,
364size,
365kernel,
366dilation,
367pad,
368group,
369order,
370use_bias,
371engine,
372force_algo_fwd,
373force_algo_dgrad,
374force_algo_wgrad,
375gc,
376dc,
377):
378# TODO: Group conv in NHWC not implemented for GPU yet.
379# TODO: Group 1D conv in NCHW not implemented for GPU yet.
380assume(
381group == 1
382or (n != 1 and order == "NCHW")
383or gc.device_type == caffe2_pb2.CPU
384)
385if group != 1 and (n == 1 or order == "NHWC"):
386dc = [d for d in dc if d.device_type == caffe2_pb2.CPU]
387
388input_channels = group * input_channels_per_group
389output_channels = group * output_channels_per_group
390
391dkernel = dilation * (kernel - 1) + 1
392for op_type in ["Conv", "Conv" + str(n) + "D"]:
393op = core.CreateOperator(
394op_type,
395["X", "w", "b"] if use_bias else ["X", "w"],
396["Y"],
397strides=[stride] * n,
398kernels=[kernel] * n,
399dilations=[dilation] * n,
400pads=[pad] * n * 2,
401group=group,
402order=order,
403engine=engine,
404force_algo_fwd=force_algo_fwd,
405force_algo_dgrad=force_algo_dgrad,
406force_algo_wgrad=force_algo_wgrad,
407)
408
409input_dims = [batch_size, input_channels]
410input_dims.extend([size] * n)
411filter_dims = [output_channels, input_channels // group]
412filter_dims.extend([kernel] * n)
413
414X = np.random.rand(*input_dims).astype(np.float32) - 0.5
415w = np.random.rand(*filter_dims).astype(np.float32) - 0.5
416b = np.random.rand(output_channels).astype(np.float32) - 0.5
417if order == "NHWC":
418X = utils.NCHW2NHWC(X)
419w = utils.NCHW2NHWC(w)
420
421inputs = [X, w, b] if use_bias else [X, w]
422
423if size + pad + pad < dkernel or size + pad + pad < dkernel:
424with self.assertRaises(RuntimeError):
425self.assertDeviceChecks(dc, op, inputs, [0])
426return
427
428self.assertDeviceChecks(dc, op, inputs, [0])
429for i in range(len(inputs)):
430self.assertGradientChecks(gc, op, inputs, i, [0])
431
    @given(
        input_channels=st.integers(1, 3),
        output_channels=st.integers(1, 2),
        batch_size=st.integers(0, 3),
        stride=st.integers(1, 3),
        size=st.integers(7, 10),
        kernel=st.integers(1, 2),
        dilation=st.integers(1, 3),
        pad=st.integers(0, 3),
        group=st.integers(1, 2),
        order=st.sampled_from(["NCHW", "NHWC"]),
        use_bias=st.booleans(),
        engine=st.sampled_from(["", "CUDNN"]),
        force_algo_fwd=_cudnn_convolution_algo_count("fwd"),
        force_algo_dgrad=_cudnn_convolution_algo_count("dgrad"),
        force_algo_wgrad=_cudnn_convolution_algo_count("wgrad"),
        **hu.gcs
    )
    @settings(deadline=10000)
    def test_1d_convolution(
        self,
        input_channels,
        output_channels,
        batch_size,
        stride,
        size,
        kernel,
        dilation,
        pad,
        group,
        order,
        use_bias,
        engine,
        force_algo_fwd,
        force_algo_dgrad,
        force_algo_wgrad,
        gc,
        dc,
    ):
        """1D convolution gradient checks, delegating to _nd_convolution."""
        if hiputl.run_in_hip(gc, dc):
            # currently miopen only supports 2d conv
            assume(engine != "CUDNN")  # CUDNN is aliased to MIOPEN for HIP
        # TODO: 1D conv in NHWC not implemented for GPU yet.
        assume(order == "NCHW" or gc.device_type == caffe2_pb2.CPU)
        if order == "NHWC":
            # Restrict device checks to CPU for the unimplemented combo.
            dc = [d for d in dc if d.device_type == caffe2_pb2.CPU]

        self._nd_convolution(
            1,
            input_channels,
            output_channels,
            batch_size,
            stride,
            size,
            kernel,
            dilation,
            pad,
            group,
            order,
            use_bias,
            engine,
            force_algo_fwd,
            force_algo_dgrad,
            force_algo_wgrad,
            gc,
            dc,
        )
499
    @given(
        input_channels=st.integers(1, 2),
        output_channels=st.integers(1, 2),
        batch_size=st.integers(0, 2),
        stride=st.integers(1, 2),
        size=st.integers(4, 5),
        kernel=st.integers(1, 2),
        dilation=st.integers(1, 2),
        pad=st.integers(0, 2),
        group=st.integers(1, 2),
        order=st.sampled_from(["NCHW", "NHWC"]),
        use_bias=st.booleans(),
        engine=st.sampled_from(["", "MIOPEN"]),  # TODO: add "CUDNN"
        force_algo_fwd=_cudnn_convolution_algo_count("fwd"),
        force_algo_dgrad=_cudnn_convolution_algo_count("dgrad"),
        force_algo_wgrad=_cudnn_convolution_algo_count("wgrad"),
        **hu.gcs
    )
    @settings(max_examples=20, deadline=None)
    def test_3d_convolution(
        self,
        input_channels,
        output_channels,
        batch_size,
        stride,
        size,
        kernel,
        dilation,
        pad,
        group,
        order,
        use_bias,
        engine,
        force_algo_fwd,
        force_algo_dgrad,
        force_algo_wgrad,
        gc,
        dc,
    ):
        """3D convolution gradient checks, delegating to _nd_convolution."""
        # TODO: 3D conv in NHWC not implemented for GPU yet.
        assume(order == "NCHW" or gc.device_type == caffe2_pb2.CPU)
        if order == "NHWC":
            # Restrict device checks to CPU for the unimplemented combo.
            dc = [d for d in dc if d.device_type == caffe2_pb2.CPU]
        self._nd_convolution(
            3,
            input_channels,
            output_channels,
            batch_size,
            stride,
            size,
            kernel,
            dilation,
            pad,
            group,
            order,
            use_bias,
            engine,
            force_algo_fwd,
            force_algo_dgrad,
            force_algo_wgrad,
            gc,
            dc,
        )
563
564@given(
565op_type=st.sampled_from(["Conv", "Conv3D"]),
566batch_size=st.integers(0, 2),
567stride=st.integers(1, 2),
568size=st.integers(3, 5),
569kernel=st.integers(1, 2),
570dilation=st.integers(1, 2),
571pad=st.integers(0, 2),
572use_bias=st.booleans(),
573force_algo_fwd=_cudnn_convolution_algo_count("fwd"),
574force_algo_dgrad=_cudnn_convolution_algo_count("dgrad"),
575force_algo_wgrad=_cudnn_convolution_algo_count("wgrad"),
576**hu.gcs_no_hip
577) # MIOPEN doesn't support 3D conv yet
578@settings(deadline=10000)
579def test_3d_convolution_cudnn_nchw(
580self,
581op_type,
582batch_size,
583stride,
584size,
585kernel,
586dilation,
587pad,
588use_bias,
589force_algo_fwd,
590force_algo_dgrad,
591force_algo_wgrad,
592gc,
593dc,
594):
595input_channels = 1
596output_channels = 1
597n = 3
598dkernel = dilation * (kernel - 1) + 1
599order = "NCHW"
600
601op = core.CreateOperator(
602op_type,
603["X", "w", "b"] if use_bias else ["X", "w"],
604["Y"],
605strides=[stride] * n,
606kernels=[kernel] * n,
607dilations=[dilation] * n,
608pads=[pad] * n * 2,
609order=order,
610engine="CUDNN",
611force_algo_fwd=force_algo_fwd,
612force_algo_dgrad=force_algo_dgrad,
613force_algo_wgrad=force_algo_wgrad,
614)
615
616input_dims = [batch_size, input_channels]
617input_dims.extend([size] * n)
618filter_dims = [output_channels, input_channels]
619filter_dims.extend([kernel] * n)
620X = np.random.rand(*input_dims).astype(np.float32) - 0.5
621w = np.random.rand(*filter_dims).astype(np.float32) - 0.5
622b = np.random.rand(output_channels).astype(np.float32) - 0.5
623
624inputs = [X, w, b] if use_bias else [X, w]
625
626if size + pad + pad < dkernel or size + pad + pad < dkernel:
627with self.assertRaises(RuntimeError):
628self.assertDeviceChecks(dc, op, inputs, [0])
629return
630
631try:
632self.assertDeviceChecks(dc, op, inputs, [0])
633except RuntimeError as e:
634es = str(e)
635# CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM should always have
636# implementation
637if (
638"status == CUDNN_STATUS_SUCCESS" not in es
639or "CUDNN_STATUS_NOT_SUPPORTED" not in es
640or force_algo_fwd == 0
641):
642raise e
643
644for i in range(len(inputs)):
645try:
646self.assertGradientChecks(gc, op, inputs, i, [0])
647except RuntimeError as e:
648es = str(e)
649if (
650"status == CUDNN_STATUS_SUCCESS" not in es
651or "CUDNN_STATUS_NOT_SUPPORTED" not in es
652):
653raise e
654
    @given(
        op_type=st.sampled_from(["Conv", "Conv2D"]),
        stride=st.integers(1, 3),
        pad=st.integers(0, 3),
        kernel=st.integers(1, 5),
        dilation=st.integers(1, 3),
        size=st.integers(7, 10),
        input_channels=st.integers(1, 8),
        output_channels=st.integers(1, 8),
        batch_size=st.integers(0, 3),
        use_bias=st.booleans(),
        **hu.gcs
    )
    @settings(deadline=None, max_examples=50)
    def test_convolution_layout(
        self,
        op_type,
        stride,
        pad,
        kernel,
        dilation,
        size,
        input_channels,
        output_channels,
        batch_size,
        use_bias,
        gc,
        dc,
    ):
        """Run the same conv across every (order, engine) combination the
        device supports and assert all outputs agree in NCHW space."""
        assume(size >= dilation * (kernel - 1) + 1)

        # Inputs are generated in NHWC and transposed per-layout below.
        X = (
            np.random.rand(batch_size, size, size, input_channels).astype(np.float32)
            - 0.5
        )
        w = (
            np.random.rand(output_channels, kernel, kernel, input_channels).astype(
                np.float32
            )
            - 0.5
        )
        b = np.random.rand(output_channels).astype(np.float32) - 0.5
        Output = collections.namedtuple("Output", ["Y", "engine", "order"])
        outputs = []

        for order in ["NCHW", "NHWC"]:
            engine_list = [""]
            if hiputl.run_in_hip(gc, dc):
                # MIOPEN only supports NCHW.
                if order == "NCHW":
                    engine_list.append("MIOPEN")
            else:
                if _cudnn_supports(dilation=(dilation > 1), nhwc=(order == "NHWC")):
                    engine_list.append("CUDNN")

            for engine in engine_list:
                op = core.CreateOperator(
                    op_type,
                    ["X", "w", "b"] if use_bias else ["X", "w"],
                    ["Y"],
                    stride=stride,
                    kernel=kernel,
                    dilation=dilation,
                    pad=pad,
                    order=order,
                    engine=engine,
                    device_option=gc,
                    exhaustive_search=True,
                )
                if order == "NCHW":
                    X_f = utils.NHWC2NCHW(X)
                    w_f = utils.NHWC2NCHW(w)
                else:
                    X_f = X
                    w_f = w
                self.assertDeviceChecks(
                    dc, op, [X_f, w_f, b] if use_bias else [X_f, w_f], [0]
                )
                self.ws.create_blob("X").feed(X_f, device_option=gc)
                self.ws.create_blob("w").feed(w_f, device_option=gc)
                self.ws.create_blob("b").feed(b, device_option=gc)
                self.ws.run(op)
                outputs.append(
                    Output(Y=self.ws.blobs["Y"].fetch(), engine=engine, order=order)
                )

        def canonical(o):
            # Normalize every output to NCHW before comparing.
            if o.order == "NHWC":
                return utils.NHWC2NCHW(o.Y)
            else:
                return o.Y

        for o in outputs:
            np.testing.assert_allclose(
                canonical(outputs[0]), canonical(o), atol=1e-4, rtol=1e-4
            )
750
    @given(
        num_workers=st.integers(1, 4),
        net_type=st.sampled_from(
            ["simple", "dag"]
            + (
                ["async_dag"]
                if workspace.has_gpu_support
                else []
            )
        ),
        engine=st.sampled_from(["CUDNN", ""]),
        **hu.gcs_no_hip
    )
    @settings(deadline=None)
    def test_convolution_sync(self, net_type, num_workers, engine, gc, dc):
        """Determinism check: a binary tree of conv layers run repeatedly
        under parallel net executors must produce identical gradients, and
        those gradients must match a fixed checksum."""
        m = ModelHelper(name="test_model")
        n = 1
        d = 2
        depth = 3
        iters = 5
        h = 5
        w = 5
        workspace.ResetWorkspace()

        use_cudnn = engine == "CUDNN"

        # Fixed seed: the final checksum assertion below depends on it.
        np.random.seed(1701)
        # Build a binary tree of conv layers, summing at each node.
        for i in reversed(range(depth)):
            for j in range(2 ** i):
                bottom_1 = "{}_{}".format(i + 1, 2 * j)
                bottom_2 = "{}_{}".format(i + 1, 2 * j + 1)
                mid_1 = "{}_{}_m".format(i + 1, 2 * j)
                mid_2 = "{}_{}_m".format(i + 1, 2 * j + 1)
                top = "{}_{}".format(i, j)
                w1, b1, w2, b2 = np.random.randn(4).tolist()
                brew.conv(
                    m,
                    bottom_1,
                    mid_1,
                    dim_in=d,
                    dim_out=d,
                    kernel=3,
                    weight_init=("ConstantFill", {"value": w1}),
                    bias_init=("ConstantFill", {"value": b1}),
                    cudnn_state=np.random.randint(0, 3),
                    stride=1,
                    pad=1,
                    deterministic=1,
                    use_cudnn=use_cudnn,
                    engine=engine,
                )
                brew.conv(
                    m,
                    bottom_2,
                    mid_2,
                    dim_in=d,
                    dim_out=d,
                    kernel=3,
                    stride=1,
                    pad=1,
                    weight_init=("ConstantFill", {"value": w2}),
                    bias_init=("ConstantFill", {"value": b2}),
                    deterministic=1,
                    cudnn_state=np.random.randint(0, 3),
                    use_cudnn=use_cudnn,
                    engine=engine,
                )
                m.net.Sum([mid_1, mid_2], top)

        m.net.Flatten(["0_0"], ["0_0_flat"])
        m.net.SquaredL2Distance(["0_0_flat", "label"], "xent")
        m.net.AveragedLoss("xent", "loss")
        input_to_grad = m.AddGradientOperators(["loss"])
        m.Proto().device_option.CopyFrom(gc)
        m.param_init_net.Proto().device_option.CopyFrom(gc)
        m.Proto().type = net_type
        m.Proto().num_workers = num_workers
        self.ws.run(m.param_init_net)

        def run():
            import numpy as np

            # Re-seed so every iteration feeds identical inputs.
            np.random.seed(1701)
            input_blobs = ["{}_{}".format(depth, j) for j in range(2 ** depth)]
            for input_blob in input_blobs:
                self.ws.create_blob(input_blob).feed(
                    np.random.randn(n, d, h, w).astype(np.float32), device_option=gc
                )
                self.ws.create_blob("label").feed(
                    np.random.randn(n, d * h * w).astype(np.float32), device_option=gc
                )
            self.ws.run(m.net)
            gradients = [
                self.ws.blobs[str(input_to_grad[input_blob])].fetch()
                for input_blob in input_blobs
            ]
            return gradients

        outputs = [run() for _ in range(iters)]
        for output in outputs[1:]:
            np.testing.assert_array_equal(outputs[0], output)
            # Magic constant pins the expected gradient checksum for the
            # seeded inputs/weights above.
            np.testing.assert_allclose(
                np.sum(np.square(output)), 1763719461732352.0, rtol=1e-5
            )
856
    def test_use_cudnn_engine_interactions(self):
        """Make sure the use_cudnn and engine kwargs work as expected."""
        for model_default in [None, True, False]:
            arg_scope = {}
            if model_default is not None:
                arg_scope["use_cudnn"] = model_default
            else:
                model_default = True  # the default

            model = ModelHelper(arg_scope=arg_scope)
            self.assertEqual(model.arg_scope["use_cudnn"], model_default)
            f = functools.partial(brew.conv, model, "conv_in", "conv_out", 10, 10, 5)

            for op_cudnn in [None, True, False]:
                for op_engine in [None, "", "CUDNN"]:
                    kwargs = {}
                    if op_cudnn is not None:
                        kwargs["use_cudnn"] = op_cudnn
                    else:
                        # NOTE: rebinds the loop variable; harmless since
                        # calculated_cudnn below reads kwargs, not op_cudnn.
                        op_cudnn = False  # the default
                    if op_engine is not None:
                        kwargs["engine"] = op_engine

                    calculated_cudnn = kwargs.get("use_cudnn", model_default)
                    expected_engine = kwargs.get(
                        "engine", "CUDNN" if calculated_cudnn else ""
                    )

                    # Contradictory combinations must raise; otherwise the
                    # op should end up with the resolved engine.
                    if (calculated_cudnn is False and op_engine == "CUDNN") or (
                        calculated_cudnn is True and op_engine == ""
                    ):
                        with self.assertRaises(ValueError):
                            f(**kwargs)
                    else:
                        f(**kwargs)
                        self.assertEqual(model.Proto().op[-1].engine, expected_engine)
893
    @given(
        op_type=st.sampled_from(["Conv", "Conv2D"]),
        N=st.integers(0, 3),
        G=st.integers(1, 3),
        DX=st.integers(1, 3),
        DY=st.integers(1, 3),
        H=st.integers(1, 3),
        W=st.integers(1, 3),
        use_bias=st.booleans(),
        order=st.sampled_from(["NCHW", "NHWC"]),
        force_algo_fwd=_cudnn_convolution_algo_count("fwd"),
        force_algo_dgrad=_cudnn_convolution_algo_count("dgrad"),
        force_algo_wgrad=_cudnn_convolution_algo_count("wgrad"),
        **hu.gcs
    )
    @settings(deadline=10000)
    def test_1x1_conv(
        self,
        op_type,
        N,
        G,
        DX,
        DY,
        H,
        W,
        use_bias,
        order,
        force_algo_fwd,
        force_algo_dgrad,
        force_algo_wgrad,
        gc,
        dc,
    ):
        """1x1 (pointwise) conv against a NumPy matmul reference, plus
        device and gradient checks.  G groups with DX/DY channels each."""
        if hiputl.run_in_hip(gc, dc):
            assume(order == "NCHW")
        if order == "NHWC":
            # Group conv in NHWC not supported here; force a single group.
            G = 1

        C = G * DX
        M = G * DY

        op = core.CreateOperator(
            op_type,
            ["X", "filter", "bias"] if use_bias else ["X", "filter"],
            ["Y"],
            stride_h=1,
            stride_w=1,
            pad_t=0,
            pad_l=0,
            pad_b=0,
            pad_r=0,
            kernel=1,
            order=order,
            group=G,
            force_algo_fwd=force_algo_fwd,
            force_algo_dgrad=force_algo_dgrad,
            force_algo_wgrad=force_algo_wgrad,
        )

        if order == "NCHW":
            X = np.random.randn(N, C, H, W).astype(np.float32)
            filter = np.random.randn(M, DX, 1, 1).astype(np.float32)
        else:
            X = np.random.randn(N, H, W, C).astype(np.float32)
            filter = np.random.randn(M, 1, 1, DX).astype(np.float32)
        bias = np.random.randn(M).astype(np.float32)
        inputs = [X, filter, bias] if use_bias else [X, filter]

        def conv_1x1_nchw_ref(X, filter, bias=None):
            # Reference: per-group matmul over flattened spatial dims.
            if N == 0:
                Y = np.zeros(shape=(N, M, H, W), dtype=np.float32)
                return [Y]

            X = X.reshape(N, G, DX, -1)
            filter = filter.reshape(G, DY, DX)
            Y = np.zeros(shape=(N, G, DY, H * W), dtype=np.float32)
            for i in range(N):
                for j in range(G):
                    Y[i, j, :, :] = np.dot(filter[j, :, :], X[i, j, :, :])
            Y = Y.reshape(N, M, H, W)
            if bias is not None:
                bias = bias.reshape(1, M, 1, 1)
                Y = np.add(Y, bias)
            return [Y]

        def conv_1x1_nhwc_ref(X, filter, bias=None):
            # Same reference with channels-last layout.
            if N == 0:
                Y = np.zeros(shape=(N, H, W, M), dtype=np.float32)
                return [Y]

            X = X.reshape(N, -1, G, DX)
            filter = filter.reshape(G, DY, DX)
            Y = np.zeros(shape=(N, H * W, G, DY), dtype=np.float32)
            for i in range(N):
                for j in range(G):
                    Y[i, :, j, :] = np.dot(X[i, :, j, :], filter[j, :, :].transpose())
            Y = Y.reshape(N, H, W, M)
            if bias is not None:
                bias = bias.reshape(1, 1, 1, M)
                Y = np.add(Y, bias)
            return [Y]

        if order == "NCHW":
            conv_1x1_ref = conv_1x1_nchw_ref
        else:
            conv_1x1_ref = conv_1x1_nhwc_ref
        self.assertReferenceChecks(
            device_option=gc, op=op, inputs=inputs, reference=conv_1x1_ref
        )
        self.assertDeviceChecks(dc, op, inputs, [0])
        for i in range(len(inputs)):
            self.assertGradientChecks(gc, op, inputs, i, [0])
1006
1007
if __name__ == "__main__":
    # Allow running this test module directly.
    unittest.main()
1010