import unittest
import torch
import torch.nn as nn
from common_utils import TestCase
import time
import torch.nn.functional as F
import os


def nms(dets, scores, threshold, sorted=False):
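    # Thin wrapper over the IPEX NMS kernel.  Judging from test_nms_kernel_result
    # below, `sorted=True` signals that `dets`/`scores` are already ordered by
    # descending score (presumably letting the kernel skip its own sort).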
    return torch.ops.torch_ipex.nms(dets, scores, threshold, sorted)


batch_score_nms = torch.ops.torch_ipex.batch_score_nms
parallel_scale_back_batch = torch.ops.torch_ipex.parallel_scale_back_batch
rpn_nms = torch.ops.torch_ipex.rpn_nms
box_head_nms = torch.ops.torch_ipex.box_head_nms
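# The four bindings above resolve to IPEX custom ops registered under
# torch.ops.torch_ipex; the tests below compare each of them against a pure
# PyTorch reference implementation, and additionally exercise them under CPU
# autocast and with float64 inputs.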
18

19

20
def get_rand_seed():
21
    return int(time.time() * 1000000000)
22

23

24
# This function is from https://github.com/kuangliu/pytorch-ssd.
25
def calc_iou_tensor(box1, box2):
26
    """Calculation of IoU based on two boxes tensor,
27
    Reference to https://github.com/kuangliu/pytorch-ssd
28
    input:
29
        box1 (N, 4)
30
        box2 (M, 4)
31
    output:
32
        IoU (N, M)
33
    """
34
    N = box1.size(0)
35
    M = box2.size(0)
36
    be1 = box1.unsqueeze(1).expand(-1, M, -1)
37
    be2 = box2.unsqueeze(0).expand(N, -1, -1)
38
    # Left Top & Right Bottom
39
    lt = torch.max(be1[:, :, :2], be2[:, :, :2])
40
    # mask1 = (be1[:,:, 0] < be2[:,:, 0]) ^ (be1[:,:, 1] < be2[:,:, 1])
41
    # mask1 = ~mask1
42
    rb = torch.min(be1[:, :, 2:], be2[:, :, 2:])
43
    # mask2 = (be1[:,:, 2] < be2[:,:, 2]) ^ (be1[:,:, 3] < be2[:,:, 3])
44
    # mask2 = ~mask2
45
    delta = rb - lt
46
    delta[delta < 0] = 0
47
    intersect = delta[:, :, 0] * delta[:, :, 1]
48
    # *mask1.float()*mask2.float()
49
    delta1 = be1[:, :, 2:] - be1[:, :, :2]
50
    area1 = delta1[:, :, 0] * delta1[:, :, 1]
51
    delta2 = be2[:, :, 2:] - be2[:, :, :2]
52
    area2 = delta2[:, :, 0] * delta2[:, :, 1]
53
    iou = intersect / (area1 + area2 - intersect)
54
    return iou
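

# Quick sanity check for calc_iou_tensor (illustrative numbers, not used by the
# tests below): a unit box and the same box shifted right by 0.5 overlap on an
# area of 0.5, so the IoU is 0.5 / (1 + 1 - 0.5) = 1/3:
#
#   calc_iou_tensor(torch.tensor([[0.0, 0.0, 1.0, 1.0]]),
#                   torch.tensor([[0.5, 0.0, 1.5, 1.0]]))
#   # -> tensor([[0.3333]])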


class TestScaleBackBatch(TestCase):
    def scale_back_batch(self, bboxes_in, scores_in, dboxes_xywh, scale_xy, scale_wh):
        """
        Python implementation of Encoder::scale_back_batch, refer to \
            https://github.com/mlcommons/inference/blob/v0.7/others/cloud/single_stage_detector/pytorch/utils.py
        """
63
        bboxes_in[:, :, :2] = scale_xy * bboxes_in[:, :, :2]
64
        bboxes_in[:, :, 2:] = scale_wh * bboxes_in[:, :, 2:]
65
        bboxes_in[:, :, :2] = (
66
            bboxes_in[:, :, :2] * dboxes_xywh[:, :, 2:] + dboxes_xywh[:, :, :2]
67
        )
68
        bboxes_in[:, :, 2:] = bboxes_in[:, :, 2:].exp() * dboxes_xywh[:, :, 2:]
69
        # Transform format to ltrb
70
        l, t, r, b = (
71
            bboxes_in[:, :, 0] - 0.5 * bboxes_in[:, :, 2],
72
            bboxes_in[:, :, 1] - 0.5 * bboxes_in[:, :, 3],
73
            bboxes_in[:, :, 0] + 0.5 * bboxes_in[:, :, 2],
74
            bboxes_in[:, :, 1] + 0.5 * bboxes_in[:, :, 3],
75
        )
76
        bboxes_in[:, :, 0] = l
77
        bboxes_in[:, :, 1] = t
78
        bboxes_in[:, :, 2] = r
79
        bboxes_in[:, :, 3] = b
80
        return bboxes_in, F.softmax(scores_in, dim=-1)
81

82
    def test_scale_back_batch_result(self):
83
        batch_size = 16
84
        number_boxes = 1024
85
        scale_xy = 0.1
86
        scale_wh = 0.2
87
        predicted_loc = (
88
            torch.randn((batch_size, number_boxes, 4)).contiguous().to(torch.float32)
89
        )
90
        predicted_score = (
91
            torch.randn((batch_size, number_boxes, 81)).contiguous().to(torch.float32)
92
        )
93
        dboxes_xywh = torch.randn((1, number_boxes, 4)).contiguous().to(torch.float64)
94
        bbox_res1, score_res1 = self.scale_back_batch(
95
            predicted_loc.clone(),
96
            predicted_score.clone(),
97
            dboxes_xywh.clone(),
98
            scale_xy,
99
            scale_wh,
100
        )
101
        bbox_res2, score_res2 = parallel_scale_back_batch(
102
            predicted_loc, predicted_score, dboxes_xywh, scale_xy, scale_wh
103
        )
104
        # test autocast
105
        with torch.cpu.amp.autocast():
106
            bbox_res3, score_res3 = parallel_scale_back_batch(
107
                predicted_loc, predicted_score, dboxes_xywh, scale_xy, scale_wh
108
            )
109
        self.assertTrue(torch.allclose(bbox_res1, bbox_res2, rtol=1e-4, atol=1e-4))
110
        self.assertTrue(torch.allclose(bbox_res1, bbox_res3, rtol=1e-4, atol=1e-4))
111
        self.assertTrue(torch.allclose(score_res1, score_res2, rtol=1e-4, atol=1e-4))
112
        self.assertTrue(torch.allclose(score_res1, score_res3, rtol=1e-4, atol=1e-4))
113

114
        # test double
115
        bbox_res4, score_res4 = parallel_scale_back_batch(
116
            predicted_loc.clone().double(),
117
            predicted_score,
118
            dboxes_xywh,
119
            scale_xy,
120
            scale_wh,
121
        )
122
        self.assertEqual(bbox_res4, bbox_res2)
123
        self.assertEqual(score_res4, score_res2)
124
        self.assertTrue(bbox_res4.dtype == torch.float64)


class TestNMS(TestCase):
    def decode_single(self, bboxes_in, scores_in, criteria, max_output, max_num=200):
        """
        Python implementation of Encoder::decode_single, refer to \
            https://github.com/mlcommons/inference/blob/v0.7/others/cloud/single_stage_detector/pytorch/utils.py
        """
        # perform non-maximum suppression
        # Reference to https://github.com/amdegroot/ssd.pytorch

        bboxes_out = []
        scores_out = []
        labels_out = []
        for i, score in enumerate(scores_in.split(1, 1)):
            # skip background
            # print(score[score>0.90])
            if i == 0:
                continue
            score = score.squeeze(1)
            mask = score > 0.05
            bboxes, score = bboxes_in[mask, :], score[mask]
            if score.size(0) == 0:
                continue
            score_sorted, score_idx_sorted = score.sort(dim=0)
            # select the top max_num indices by score
            score_idx_sorted = score_idx_sorted[-max_num:]
            candidates = []
            while score_idx_sorted.numel() > 0:
                idx = score_idx_sorted[-1].item()
                bboxes_sorted = bboxes[score_idx_sorted, :]
                bboxes_idx = bboxes[idx, :].unsqueeze(dim=0)
                iou_sorted = calc_iou_tensor(bboxes_sorted, bboxes_idx).squeeze()
                # we only need iou < criteria
                score_idx_sorted = score_idx_sorted[iou_sorted < criteria]
                candidates.append(idx)

            bboxes_out.append(bboxes[candidates, :])
            scores_out.append(score[candidates])
            labels_out.extend([i] * len(candidates))
        bboxes_out, labels_out, scores_out = (
            torch.cat(bboxes_out, dim=0),
            torch.tensor(labels_out, dtype=torch.long),
            torch.cat(scores_out, dim=0),
        )
        _, max_ids = scores_out.sort(dim=0)
        max_ids = max_ids[-max_output:]
        return bboxes_out[max_ids, :], labels_out[max_ids], scores_out[max_ids]
173

174
    def test_batch_nms_result(self):
175
        batch_size = 1
176
        number_boxes = 15130
177
        scale_xy = 0.1
178
        scale_wh = 0.2
179
        criteria = 0.50
180
        max_output = 200
181
        predicted_loc = torch.load(
182
            os.path.join(os.path.dirname(__file__), "data/nms_ploc.pt")
183
        )  # sizes: [1, 15130, 4]
184
        predicted_score = torch.load(
185
            os.path.join(os.path.dirname(__file__), "data/nms_plabel.pt")
186
        )  # sizes: [1, 15130, 81]
187
        dboxes_xywh = torch.load(
188
            os.path.join(os.path.dirname(__file__), "data/nms_dboxes_xywh.pt")
189
        )
190
        bboxes, probs = parallel_scale_back_batch(
191
            predicted_loc, predicted_score, dboxes_xywh, scale_xy, scale_wh
192
        )
193
        bboxes_clone = bboxes.clone()
194
        probs_clone = probs.clone()
195

196
        output = []
197
        for bbox, prob in zip(bboxes.split(1, 0), probs.split(1, 0)):
198
            bbox = bbox.squeeze(0)
199
            prob = prob.squeeze(0)
200
            output.append(self.decode_single(bbox, prob, criteria, max_output))
201
        output2_raw = batch_score_nms(bboxes_clone, probs_clone, criteria, max_output)
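        # batch_score_nms returns flattened results: judging from the
        # re-assembly and comparison loops below, output2_raw[0..2] hold the
        # kept boxes, labels and scores for all images back to back, and
        # output2_raw[3] holds the per-image detection counts used to split them.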

        # test autocast
        with torch.cpu.amp.autocast():
            for datatype in (torch.bfloat16, torch.float32):
                bboxes_autocast = bboxes.clone().to(datatype)
                probs_autocast = probs.clone().to(datatype)
                output2_raw_autocast = batch_score_nms(
                    bboxes_autocast, probs_autocast, criteria, max_output
                )
                for i in range(3):
                    self.assertTrue(output2_raw_autocast[i].dtype == torch.float32)

        # Re-assemble the result
        output2 = []
        idx = 0
        for i in range(output2_raw[3].size(0)):
            output2.append(
                (
                    output2_raw[0][idx : idx + output2_raw[3][i]],
                    output2_raw[1][idx : idx + output2_raw[3][i]],
                    output2_raw[2][idx : idx + output2_raw[3][i]],
                )
            )
            idx += output2_raw[3][i]

        for i in range(batch_size):
            loc, label, prob = list(r for r in output[i])
            loc2, label2, prob2 = list(r for r in output2[i])
            self.assertTrue(torch.allclose(loc, loc2, rtol=1e-4, atol=1e-4))
            self.assertEqual(label, label2)
            self.assertTrue(torch.allclose(prob, prob2, rtol=1e-4, atol=1e-4))

        # test double
        output2_raw_double = batch_score_nms(
            bboxes.clone().double(), probs.clone().double(), criteria, max_output
        )
        self.assertEqual(output2_raw_double, output2_raw)
        self.assertTrue(output2_raw_double[0].dtype == torch.float64)

    def test_jit_trace_batch_nms(self):
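        # Wrap batch_score_nms in an nn.Module so it can be captured by
        # torch.jit.trace, then check that the traced module still matches the
        # pure Python decode_single reference.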
        class Batch_NMS(nn.Module):
            def __init__(self, criteria, max_output):
                super(Batch_NMS, self).__init__()
                self.criteria = criteria
                self.max_output = max_output

            def forward(self, bboxes_clone, probs_clone):
                return batch_score_nms(
                    bboxes_clone, probs_clone, self.criteria, self.max_output
                )

        batch_size = 1
        number_boxes = 15130
        scale_xy = 0.1
        scale_wh = 0.2
        criteria = 0.50
        max_output = 200
        predicted_loc = torch.load(
            os.path.join(os.path.dirname(__file__), "data/nms_ploc.pt")
        )  # sizes: [1, 15130, 4]
        predicted_score = torch.load(
            os.path.join(os.path.dirname(__file__), "data/nms_plabel.pt")
        )  # sizes: [1, 15130, 81]
        dboxes_xywh = torch.load(
            os.path.join(os.path.dirname(__file__), "data/nms_dboxes_xywh.pt")
        )
        bboxes, probs = parallel_scale_back_batch(
            predicted_loc, predicted_score, dboxes_xywh, scale_xy, scale_wh
        )
        bboxes_clone = bboxes.clone()
        probs_clone = probs.clone()

        output = []
        for bbox, prob in zip(bboxes.split(1, 0), probs.split(1, 0)):
            bbox = bbox.squeeze(0)
            prob = prob.squeeze(0)
            output.append(self.decode_single(bbox, prob, criteria, max_output))

        batch_score_nms_module = Batch_NMS(criteria, max_output)
        model_decode = torch.jit.trace(
            batch_score_nms_module, (bboxes_clone, probs_clone)
        )
        output2_raw = model_decode(bboxes_clone, probs_clone)

        # Re-assemble the result
        output2 = []
        idx = 0
        for i in range(output2_raw[3].size(0)):
            output2.append(
                (
                    output2_raw[0][idx : idx + output2_raw[3][i]],
                    output2_raw[1][idx : idx + output2_raw[3][i]],
                    output2_raw[2][idx : idx + output2_raw[3][i]],
                )
            )
            idx += output2_raw[3][i]

        for i in range(batch_size):
            loc, label, prob = list(r for r in output[i])
            loc2, label2, prob2 = list(r for r in output2[i])
            self.assertTrue(torch.allclose(loc, loc2, rtol=1e-4, atol=1e-4))
            self.assertEqual(label, label2)
            self.assertTrue(torch.allclose(prob, prob2, rtol=1e-4, atol=1e-4))

    def test_nms_kernel_result(self):
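        # Per-class check of the raw nms kernel: running it on pre-sorted
        # inputs with sorted=True must keep the same boxes/scores as running it
        # on unsorted inputs with sorted=False, and a float64 run must agree as
        # well (results are compared as sorted sets since the index orders differ).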
        batch_size = 1
        class_number = 81
        scale_xy = 0.1
        scale_wh = 0.2
        criteria = 0.50
        max_output = 200
        predicted_loc = torch.load(
            os.path.join(os.path.dirname(__file__), "data/nms_ploc.pt")
        )  # sizes: [1, 15130, 4]
        predicted_score = torch.load(
            os.path.join(os.path.dirname(__file__), "data/nms_plabel.pt")
        )  # sizes: [1, 15130, 81]
        dboxes_xywh = torch.load(
            os.path.join(os.path.dirname(__file__), "data/nms_dboxes_xywh.pt")
        )
        bboxes, probs = parallel_scale_back_batch(
            predicted_loc, predicted_score, dboxes_xywh, scale_xy, scale_wh
        )

        for bs in range(batch_size):
            loc = bboxes[bs].squeeze(0)
            for class_id in range(class_number):
                if class_id == 0:
                    # Skip the background
                    continue
                score = probs[bs, :, class_id]

                score_sorted, indices = torch.sort(score, descending=True)
                loc_sorted = torch.index_select(loc, 0, indices)

                result = nms(loc_sorted.clone(), score_sorted.clone(), criteria, True)
                result_ref = nms(loc.clone(), score.clone(), criteria, False)
                result_ref2 = nms(
                    loc_sorted.clone().to(dtype=torch.float64),
                    score_sorted.clone().to(dtype=torch.float64),
                    criteria,
                    True,
                )

                bbox_keep, _ = torch.sort(
                    torch.index_select(loc_sorted, 0, result).squeeze(0), 0
                )
                bbox_keep_ref, _ = torch.sort(
                    torch.index_select(loc, 0, result_ref).squeeze(0), 0
                )
                bbox_keep_ref2, _ = torch.sort(
                    torch.index_select(loc_sorted, 0, result_ref2).squeeze(0), 0
                )

                score_keep, _ = torch.sort(
                    torch.index_select(score_sorted, 0, result).squeeze(0), 0
                )
                score_keep_ref, _ = torch.sort(
                    torch.index_select(score, 0, result_ref).squeeze(0), 0
                )
                score_keep_ref2, _ = torch.sort(
                    torch.index_select(score_sorted, 0, result_ref2).squeeze(0), 0
                )

                self.assertEqual(result.size(0), result_ref.size(0))
                self.assertTrue(
                    torch.allclose(bbox_keep, bbox_keep_ref, rtol=1e-4, atol=1e-4)
                )
                self.assertTrue(
                    torch.allclose(score_keep, score_keep_ref, rtol=1e-4, atol=1e-4)
                )
                self.assertTrue(
                    torch.allclose(bbox_keep, bbox_keep_ref2, rtol=1e-4, atol=1e-4)
                )
                self.assertTrue(
                    torch.allclose(score_keep, score_keep_ref2, rtol=1e-4, atol=1e-4)
                )

                # test autocast
                with torch.cpu.amp.autocast():
                    result_autocast = nms(loc.clone(), score.clone(), criteria, False)
                    self.assertEqual(result_autocast, result_ref)

                # test double
                result_double = nms(
                    loc.clone().double(), score.clone().double(), criteria, False
                )
                self.assertEqual(result_double, result_ref)

    def test_rpn_nms_result(self):
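        # Python reference for RPN post-processing: clip proposals to the image,
        # drop proposals whose width or height is below min_size, run NMS per
        # image and keep at most post_nms_top_n proposals; the fused rpn_nms op
        # should reproduce this result.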
        image_shapes = [(800, 824), (800, 1199)]
        min_size = 0
        nms_thresh = 0.7
        post_nms_top_n = 1000
        proposals = torch.load(
            os.path.join(os.path.dirname(__file__), "data/rpn_nms_proposals.pt")
        )
        objectness = torch.load(
            os.path.join(os.path.dirname(__file__), "data/rpn_nms_objectness.pt")
        )

        new_proposal = []
        new_score = []
        for proposal, score, im_shape in zip(
            proposals.clone(), objectness.clone(), image_shapes
        ):
            proposal[:, 0].clamp_(min=0, max=im_shape[0] - 1)
            proposal[:, 1].clamp_(min=0, max=im_shape[1] - 1)
            proposal[:, 2].clamp_(min=0, max=im_shape[0] - 1)
            proposal[:, 3].clamp_(min=0, max=im_shape[1] - 1)
            keep = (
                (
                    (proposal[:, 2] - proposal[:, 0] >= min_size)
                    & (proposal[:, 3] - proposal[:, 1] >= min_size)
                )
                .nonzero()
                .squeeze(1)
            )
            proposal = proposal[keep]
            score = score[keep]
            if nms_thresh > 0:
                keep = nms(proposal, score, nms_thresh)
                if post_nms_top_n > 0:
                    keep = keep[:post_nms_top_n]
            new_proposal.append(proposal[keep])
            new_score.append(score[keep])

        new_proposal_, new_score_ = rpn_nms(
            proposals, objectness, image_shapes, min_size, nms_thresh, post_nms_top_n
        )

        self.assertEqual(new_proposal, new_proposal_)
        self.assertEqual(new_score, new_score_)

        # test autocast
        with torch.cpu.amp.autocast():
            for datatype in (torch.bfloat16, torch.float32):
                proposals_autocast = proposals.clone().to(datatype)
                objectness_autocast = objectness.clone().to(datatype)
                new_proposal_autocast, new_score_autocast = rpn_nms(
                    proposals_autocast,
                    objectness_autocast,
                    image_shapes,
                    min_size,
                    nms_thresh,
                    post_nms_top_n,
                )
                self.assertTrue(new_proposal_autocast[0].dtype == torch.float32)
                self.assertTrue(new_score_autocast[0].dtype == torch.float32)

        # test double
        new_proposal_double, new_score_double = rpn_nms(
            proposals.clone().double(),
            objectness.clone().double(),
            image_shapes,
            min_size,
            nms_thresh,
            post_nms_top_n,
        )
        self.assertEqual(new_proposal_double, new_proposal)
        self.assertEqual(new_score_double, new_score)
        self.assertTrue(new_proposal_double[0].dtype == torch.float64)
        self.assertTrue(new_score_double[0].dtype == torch.float64)

    def test_box_head_nms_result(self):
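        # Python reference for the box-head post-processing: clip boxes, apply
        # the score threshold and per-class NMS (class 0 is background), then
        # cap the number of detections per image; the fused box_head_nms op
        # should reproduce this result.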
        image_shapes = [(800, 824), (800, 1199)]
        score_thresh = 0.05
        nms_ = 0.5
        detections_per_img = 100
        num_classes = 81
        proposals = torch.load(
            os.path.join(os.path.dirname(__file__), "data/box_head_nms_proposals.pt")
        )
        class_prob = torch.load(
            os.path.join(os.path.dirname(__file__), "data/box_head_nms_class_prob.pt")
        )

        boxes_out = []
        scores_out = []
        labels_out = []
        for scores, boxes, image_shape in zip(class_prob, proposals, image_shapes):
            boxes = boxes.reshape(-1, 4)
            boxes[:, 0].clamp_(min=0, max=image_shape[0] - 1)
            boxes[:, 1].clamp_(min=0, max=image_shape[1] - 1)
            boxes[:, 2].clamp_(min=0, max=image_shape[0] - 1)
            boxes[:, 3].clamp_(min=0, max=image_shape[1] - 1)
            boxes = boxes.reshape(-1, num_classes * 4)
            scores = scores.reshape(-1, num_classes)

            inds_all = scores > score_thresh
            new_boxes = []
            new_scores = []
            new_labels = []
            for j in range(1, num_classes):
                inds = inds_all[:, j].nonzero().squeeze(1)
                scores_j = scores[inds, j]
                boxes_j = boxes[inds, j * 4 : (j + 1) * 4]
                if nms_ > 0:
                    keep = nms(boxes_j, scores_j, nms_)
                new_boxes.append(boxes_j[keep])
                new_scores.append(scores_j[keep])
                new_labels.append(torch.full((len(keep),), j, dtype=torch.int64))

            new_boxes, new_scores, new_labels = (
                torch.cat(new_boxes, dim=0),
                torch.cat(new_scores, dim=0),
                torch.cat(new_labels, dim=0),
            )
            number_of_detections = new_boxes.size(0)
            if number_of_detections > detections_per_img > 0:
                image_thresh, _ = torch.kthvalue(
                    new_scores, number_of_detections - detections_per_img + 1
                )
                keep = new_scores >= image_thresh.item()
                keep = torch.nonzero(keep).squeeze(1)
                boxes_out.append(new_boxes[keep])
                scores_out.append(new_scores[keep])
                labels_out.append(new_labels[keep])
            else:
                boxes_out.append(new_boxes)
                scores_out.append(new_scores)
                labels_out.append(new_labels)

        boxes_out_, scores_out_, labels_out_ = box_head_nms(
            proposals,
            class_prob,
            image_shapes,
            score_thresh,
            nms_,
            detections_per_img,
            num_classes,
        )

        self.assertEqual(boxes_out, boxes_out_)
        self.assertEqual(scores_out, scores_out_)
        self.assertEqual(labels_out, labels_out_)
538

539
        # test autocast
540
        with torch.cpu.amp.autocast():
541
            for datatype in (torch.bfloat16, torch.float32):
542
                proposals_autocast = (
543
                    proposals[0].to(datatype),
544
                    proposals[1].to(datatype),
545
                )
546
                class_prob_autocast = (
547
                    class_prob[0].to(datatype),
548
                    class_prob[1].to(datatype),
549
                )
550
                (
551
                    boxes_out_autocast,
552
                    scores_out_autocast,
553
                    labels_out_autocast,
554
                ) = box_head_nms(
555
                    proposals_autocast,
556
                    class_prob_autocast,
557
                    image_shapes,
558
                    score_thresh,
559
                    nms_,
560
                    detections_per_img,
561
                    num_classes,
562
                )
563
                self.assertTrue(boxes_out_autocast[0].dtype == torch.float32)
564
                self.assertTrue(scores_out_autocast[0].dtype == torch.float32)
565

566
        # test double
567
        proposals_double = (proposals[0].double(), proposals[1].double())
568
        class_prob_double = (class_prob[0].double(), class_prob[1].double())
569
        boxes_out_double, scores_out_double, labels_out_double = box_head_nms(
570
            proposals_double,
571
            class_prob_double,
572
            image_shapes,
573
            score_thresh,
574
            nms_,
575
            detections_per_img,
576
            num_classes,
577
        )
578
        self.assertEqual(boxes_out_double, boxes_out)
579
        self.assertEqual(scores_out_double, scores_out)
580
        self.assertEqual(labels_out_double, labels_out)
581
        self.assertTrue(boxes_out_double[0].dtype == torch.float64)
582
        self.assertTrue(scores_out_double[0].dtype == torch.float64)
583

584

585
if __name__ == "__main__":
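    # Running this file directly executes every test above; the common_utils
    # helper module and the data/*.pt fixtures referenced earlier must be
    # available for the data-driven tests to load.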
    test = unittest.main()