pytorch

Форк
0
/
test_indexing.py 
1682 строки · 73.4 Кб
1
# Owner(s): ["module: tests"]
2

3
import torch
4
from torch import tensor
5

6
import unittest
7
import warnings
8
import random
9
from functools import reduce
10

11
import numpy as np
12

13
from torch.testing import make_tensor
14
from torch.testing._internal.common_utils import (
15
    TestCase, run_tests, skipIfTorchDynamo, DeterministicGuard)
16
from torch.testing._internal.common_device_type import (
17
    instantiate_device_type_tests, onlyCUDA, dtypes, dtypesIfCPU, dtypesIfCUDA,
18
    onlyNativeDeviceTypes, skipXLA)
19
import operator
20

21

22
class TestIndexing(TestCase):
23
    def test_index(self, device):
        """Basic indexing: ints, slices, steps, Ellipsis, None and LongTensor.

        Also checks the exceptions raised for zero/negative slice steps, too
        many indices, float indices and ``del`` on a tensor element.
        """

        def consec(size, start=1):
            # Tensor of shape ``size`` holding start, start + 1, ... in
            # row-major order (created without an explicit device).
            sequence = torch.ones(torch.tensor(size).prod(0)).cumsum(0)
            sequence.add_(start - 1)
            return sequence.view(*size)

        reference = consec((3, 3, 3)).to(device)

        # empty tensor indexing
        self.assertEqual(reference[torch.LongTensor().to(device)], reference.new(0, 3, 3))

        self.assertEqual(reference[0], consec((3, 3)), atol=0, rtol=0)
        self.assertEqual(reference[1], consec((3, 3), 10), atol=0, rtol=0)
        self.assertEqual(reference[2], consec((3, 3), 19), atol=0, rtol=0)
        self.assertEqual(reference[0, 1], consec((3,), 4), atol=0, rtol=0)
        self.assertEqual(reference[0:2], consec((2, 3, 3)), atol=0, rtol=0)
        self.assertEqual(reference[2, 2, 2], 27, atol=0, rtol=0)
        self.assertEqual(reference[:], consec((3, 3, 3)), atol=0, rtol=0)

        # indexing with Ellipsis
        self.assertEqual(reference[..., 2], torch.tensor([[3., 6., 9.],
                                                          [12., 15., 18.],
                                                          [21., 24., 27.]]), atol=0, rtol=0)
        self.assertEqual(reference[0, ..., 2], torch.tensor([3., 6., 9.]), atol=0, rtol=0)
        self.assertEqual(reference[..., 2], reference[:, :, 2], atol=0, rtol=0)
        self.assertEqual(reference[0, ..., 2], reference[0, :, 2], atol=0, rtol=0)
        self.assertEqual(reference[0, 2, ...], reference[0, 2], atol=0, rtol=0)
        self.assertEqual(reference[..., 2, 2, 2], 27, atol=0, rtol=0)
        self.assertEqual(reference[2, ..., 2, 2], 27, atol=0, rtol=0)
        self.assertEqual(reference[2, 2, ..., 2], 27, atol=0, rtol=0)
        self.assertEqual(reference[2, 2, 2, ...], 27, atol=0, rtol=0)
        self.assertEqual(reference[...], reference, atol=0, rtol=0)

        reference_5d = consec((3, 3, 3, 3, 3)).to(device)
        self.assertEqual(reference_5d[..., 1, 0], reference_5d[:, :, :, 1, 0], atol=0, rtol=0)
        self.assertEqual(reference_5d[2, ..., 1, 0], reference_5d[2, :, :, 1, 0], atol=0, rtol=0)
        self.assertEqual(reference_5d[2, 1, 0, ..., 1], reference_5d[2, 1, 0, :, 1], atol=0, rtol=0)
        self.assertEqual(reference_5d[...], reference_5d, atol=0, rtol=0)

        # LongTensor indexing
        reference = consec((5, 5, 5)).to(device)
        idx = torch.LongTensor([2, 4]).to(device)
        self.assertEqual(reference[idx], torch.stack([reference[2], reference[4]]))
        # TODO: enable once indexing is implemented like in numpy
        # self.assertEqual(reference[2, idx], torch.stack([reference[2, 2], reference[2, 4]]))
        # self.assertEqual(reference[3, idx, 1], torch.stack([reference[3, 2], reference[3, 4]])[:, 1])

        # None indexing
        self.assertEqual(reference[2, None], reference[2].unsqueeze(0))
        self.assertEqual(reference[2, None, None], reference[2].unsqueeze(0).unsqueeze(0))
        self.assertEqual(reference[2:4, None], reference[2:4].unsqueeze(1))
        self.assertEqual(reference[None, 2, None, None], reference.unsqueeze(0)[:, 2].unsqueeze(0).unsqueeze(0))
        self.assertEqual(reference[None, 2:5, None, None], reference.unsqueeze(0)[:, 2:5].unsqueeze(2).unsqueeze(2))

        # indexing 0-length slice
        self.assertEqual(torch.empty(0, 5, 5), reference[slice(0)])
        self.assertEqual(torch.empty(0, 5), reference[slice(0), 2])
        self.assertEqual(torch.empty(0, 5), reference[2, slice(0)])
        self.assertEqual(torch.tensor([]), reference[2, 1:1, 2])

        # indexing with step
        reference = consec((10, 10, 10)).to(device)
        self.assertEqual(reference[1:5:2], torch.stack([reference[1], reference[3]], 0))
        self.assertEqual(reference[1:6:2], torch.stack([reference[1], reference[3], reference[5]], 0))
        self.assertEqual(reference[1:9:4], torch.stack([reference[1], reference[5]], 0))
        self.assertEqual(reference[2:4, 1:5:2], torch.stack([reference[2:4, 1], reference[2:4, 3]], 1))
        self.assertEqual(reference[3, 1:6:2], torch.stack([reference[3, 1], reference[3, 3], reference[3, 5]], 0))
        self.assertEqual(reference[None, 2, 1:9:4], torch.stack([reference[2, 1], reference[2, 5]], 0).unsqueeze(0))
        self.assertEqual(reference[:, 2, 1:6:2],
                         torch.stack([reference[:, 2, 1], reference[:, 2, 3], reference[:, 2, 5]], 1))

        # Random positive-step slices (1-D or 2-D) must agree with the same
        # slicing applied to the equivalent nested Python lists.
        lst = [list(range(i, i + 10)) for i in range(0, 100, 10)]
        # Named ``lst_tensor`` so it does not shadow the module-level
        # ``from torch import tensor`` import.
        lst_tensor = torch.DoubleTensor(lst).to(device)
        for _i in range(100):
            idx1_start = random.randrange(10)
            idx1_end = idx1_start + random.randrange(1, 10 - idx1_start + 1)
            idx1_step = random.randrange(1, 8)
            idx1 = slice(idx1_start, idx1_end, idx1_step)
            if random.randrange(2) == 0:
                idx2_start = random.randrange(10)
                idx2_end = idx2_start + random.randrange(1, 10 - idx2_start + 1)
                idx2_step = random.randrange(1, 8)
                idx2 = slice(idx2_start, idx2_end, idx2_step)
                lst_indexed = [row[idx2] for row in lst[idx1]]
                tensor_indexed = lst_tensor[idx1, idx2]
            else:
                lst_indexed = lst[idx1]
                tensor_indexed = lst_tensor[idx1]
            self.assertEqual(torch.DoubleTensor(lst_indexed), tensor_indexed)

        # zero and negative slice steps are rejected
        self.assertRaises(ValueError, lambda: reference[1:9:0])
        self.assertRaises(ValueError, lambda: reference[1:9:-1])

        # more indices than dimensions
        self.assertRaises(IndexError, lambda: reference[1, 1, 1, 1])
        self.assertRaises(IndexError, lambda: reference[1, 1, 1, 1:1])
        self.assertRaises(IndexError, lambda: reference[3, 3, 3, 3, 3, 3, 3, 3])

        # float indices / float slice bounds
        self.assertRaises(IndexError, lambda: reference[0.0])
        self.assertRaises(TypeError, lambda: reference[0.0:2.0])
        self.assertRaises(IndexError, lambda: reference[0.0, 0.0:2.0])
        self.assertRaises(IndexError, lambda: reference[0.0, :, 0.0:2.0])
        self.assertRaises(IndexError, lambda: reference[0.0, ..., 0.0:2.0])
        self.assertRaises(IndexError, lambda: reference[0.0, :, 0.0])

        def delitem():
            del reference[0]

        # tensors do not support item deletion
        self.assertRaises(TypeError, delitem)
132

133
    @onlyNativeDeviceTypes
    @dtypes(torch.half, torch.double)
    def test_advancedindex(self, device, dtype):
        """Integer-array ("advanced") indexing, for both reads and writes.

        For ``torch.half`` only the 1-D get/set helpers run. For
        ``torch.double`` the test additionally covers strided and transposed
        (non-contiguous) tensors, multi-dimensional index arrays, Ellipsis and
        partial indexing, and out-of-range errors. It also cross-checks
        results against NumPy and, on non-CPU devices, compares gradients
        against a CPU computation.
        """
        # Tests for Integer Array Indexing, Part I - Purely integer array
        # indexing

        def consec(size, start=1):
            # Creates the sequence in float since CPU half doesn't support the
            # needed operations. Converts to dtype before returning.
            numel = reduce(operator.mul, size, 1)
            sequence = torch.ones(numel, dtype=torch.float, device=device).cumsum(0)
            sequence.add_(start - 1)
            return sequence.view(*size).to(dtype=dtype)

        # pick a random valid indexer type
        def ri(indices):
            choice = random.randint(0, 2)
            if choice == 0:
                return torch.LongTensor(indices).to(device)
            elif choice == 1:
                return list(indices)
            else:
                return tuple(indices)

        def validate_indexing(x):
            self.assertEqual(x[[0]], consec((1,)))
            self.assertEqual(x[ri([0]), ], consec((1,)))
            self.assertEqual(x[ri([3]), ], consec((1,), 4))
            self.assertEqual(x[[2, 3, 4]], consec((3,), 3))
            self.assertEqual(x[ri([2, 3, 4]), ], consec((3,), 3))
            self.assertEqual(x[ri([0, 2, 4]), ], torch.tensor([1, 3, 5], dtype=dtype, device=device))

        def validate_setting(x):
            x[[0]] = -2
            self.assertEqual(x[[0]], torch.tensor([-2], dtype=dtype, device=device))
            x[[0]] = -1
            self.assertEqual(x[ri([0]), ], torch.tensor([-1], dtype=dtype, device=device))
            x[[2, 3, 4]] = 4
            self.assertEqual(x[[2, 3, 4]], torch.tensor([4, 4, 4], dtype=dtype, device=device))
            x[ri([2, 3, 4]), ] = 3
            self.assertEqual(x[ri([2, 3, 4]), ], torch.tensor([3, 3, 3], dtype=dtype, device=device))
            x[ri([0, 2, 4]), ] = torch.tensor([5, 4, 3], dtype=dtype, device=device)
            self.assertEqual(x[ri([0, 2, 4]), ], torch.tensor([5, 4, 3], dtype=dtype, device=device))

        # Only validates indexing and setting for halfs
        if dtype == torch.half:
            reference = consec((10,))
            validate_indexing(reference)
            validate_setting(reference)
            return

        # Case 1: Purely Integer Array Indexing
        reference = consec((10,))
        validate_indexing(reference)

        # setting values
        validate_setting(reference)

        # Tensor with stride != 1
        # strided is [1, 3, 5, 7]
        reference = consec((10,))
        strided = torch.tensor((), dtype=dtype, device=device)
        strided.set_(reference.storage(), storage_offset=0,
                     size=torch.Size([4]), stride=[2])

        self.assertEqual(strided[[0]], torch.tensor([1], dtype=dtype, device=device))
        self.assertEqual(strided[ri([0]), ], torch.tensor([1], dtype=dtype, device=device))
        self.assertEqual(strided[ri([3]), ], torch.tensor([7], dtype=dtype, device=device))
        self.assertEqual(strided[[1, 2]], torch.tensor([3, 5], dtype=dtype, device=device))
        self.assertEqual(strided[ri([1, 2]), ], torch.tensor([3, 5], dtype=dtype, device=device))
        self.assertEqual(strided[ri([[2, 1], [0, 3]]), ],
                         torch.tensor([[5, 3], [1, 7]], dtype=dtype, device=device))

        # stride is [4, 8]
        strided = torch.tensor((), dtype=dtype, device=device)
        strided.set_(reference.storage(), storage_offset=4,
                     size=torch.Size([2]), stride=[4])
        self.assertEqual(strided[[0]], torch.tensor([5], dtype=dtype, device=device))
        self.assertEqual(strided[ri([0]), ], torch.tensor([5], dtype=dtype, device=device))
        self.assertEqual(strided[ri([1]), ], torch.tensor([9], dtype=dtype, device=device))
        self.assertEqual(strided[[0, 1]], torch.tensor([5, 9], dtype=dtype, device=device))
        self.assertEqual(strided[ri([0, 1]), ], torch.tensor([5, 9], dtype=dtype, device=device))
        self.assertEqual(strided[ri([[0, 1], [1, 0]]), ],
                         torch.tensor([[5, 9], [9, 5]], dtype=dtype, device=device))

        # reference is 1 2
        #              3 4
        #              5 6
        reference = consec((3, 2))
        self.assertEqual(reference[ri([0, 1, 2]), ri([0])], torch.tensor([1, 3, 5], dtype=dtype, device=device))
        self.assertEqual(reference[ri([0, 1, 2]), ri([1])], torch.tensor([2, 4, 6], dtype=dtype, device=device))
        self.assertEqual(reference[ri([0]), ri([0])], consec((1,)))
        self.assertEqual(reference[ri([2]), ri([1])], consec((1,), 6))
        self.assertEqual(reference[[ri([0, 0]), ri([0, 1])]], torch.tensor([1, 2], dtype=dtype, device=device))
        self.assertEqual(reference[[ri([0, 1, 1, 0, 2]), ri([1])]],
                         torch.tensor([2, 4, 4, 2, 6], dtype=dtype, device=device))
        self.assertEqual(reference[[ri([0, 0, 1, 1]), ri([0, 1, 0, 0])]],
                         torch.tensor([1, 2, 3, 3], dtype=dtype, device=device))

        rows = ri([[0, 0],
                   [1, 2]])
        columns = [0],
        self.assertEqual(reference[rows, columns], torch.tensor([[1, 1],
                                                                 [3, 5]], dtype=dtype, device=device))

        rows = ri([[0, 0],
                   [1, 2]])
        columns = ri([1, 0])
        self.assertEqual(reference[rows, columns], torch.tensor([[2, 1],
                                                                 [4, 5]], dtype=dtype, device=device))
        rows = ri([[0, 0],
                   [1, 2]])
        columns = ri([[0, 1],
                      [1, 0]])
        self.assertEqual(reference[rows, columns], torch.tensor([[1, 2],
                                                                 [4, 5]], dtype=dtype, device=device))

        # setting values
        reference[ri([0]), ri([1])] = -1
        self.assertEqual(reference[ri([0]), ri([1])], torch.tensor([-1], dtype=dtype, device=device))
        reference[ri([0, 1, 2]), ri([0])] = torch.tensor([-1, 2, -4], dtype=dtype, device=device)
        self.assertEqual(reference[ri([0, 1, 2]), ri([0])],
                         torch.tensor([-1, 2, -4], dtype=dtype, device=device))
        reference[rows, columns] = torch.tensor([[4, 6], [2, 3]], dtype=dtype, device=device)
        self.assertEqual(reference[rows, columns],
                         torch.tensor([[4, 6], [2, 3]], dtype=dtype, device=device))

        # Verify still works with Transposed (i.e. non-contiguous) Tensors

        reference = torch.tensor([[0, 1, 2, 3],
                                  [4, 5, 6, 7],
                                  [8, 9, 10, 11]], dtype=dtype, device=device).t_()

        # Transposed: [[0, 4, 8],
        #              [1, 5, 9],
        #              [2, 6, 10],
        #              [3, 7, 11]]

        self.assertEqual(reference[ri([0, 1, 2]), ri([0])],
                         torch.tensor([0, 1, 2], dtype=dtype, device=device))
        self.assertEqual(reference[ri([0, 1, 2]), ri([1])],
                         torch.tensor([4, 5, 6], dtype=dtype, device=device))
        self.assertEqual(reference[ri([0]), ri([0])],
                         torch.tensor([0], dtype=dtype, device=device))
        self.assertEqual(reference[ri([2]), ri([1])],
                         torch.tensor([6], dtype=dtype, device=device))
        self.assertEqual(reference[[ri([0, 0]), ri([0, 1])]],
                         torch.tensor([0, 4], dtype=dtype, device=device))
        self.assertEqual(reference[[ri([0, 1, 1, 0, 3]), ri([1])]],
                         torch.tensor([4, 5, 5, 4, 7], dtype=dtype, device=device))
        self.assertEqual(reference[[ri([0, 0, 1, 1]), ri([0, 1, 0, 0])]],
                         torch.tensor([0, 4, 1, 1], dtype=dtype, device=device))

        rows = ri([[0, 0],
                   [1, 2]])
        columns = [0],
        self.assertEqual(reference[rows, columns],
                         torch.tensor([[0, 0], [1, 2]], dtype=dtype, device=device))

        rows = ri([[0, 0],
                   [1, 2]])
        columns = ri([1, 0])
        self.assertEqual(reference[rows, columns],
                         torch.tensor([[4, 0], [5, 2]], dtype=dtype, device=device))
        rows = ri([[0, 0],
                   [1, 3]])
        columns = ri([[0, 1],
                      [1, 2]])
        self.assertEqual(reference[rows, columns],
                         torch.tensor([[0, 4], [5, 11]], dtype=dtype, device=device))

        # setting values
        reference[ri([0]), ri([1])] = -1
        self.assertEqual(reference[ri([0]), ri([1])],
                         torch.tensor([-1], dtype=dtype, device=device))
        reference[ri([0, 1, 2]), ri([0])] = torch.tensor([-1, 2, -4], dtype=dtype, device=device)
        self.assertEqual(reference[ri([0, 1, 2]), ri([0])],
                         torch.tensor([-1, 2, -4], dtype=dtype, device=device))
        reference[rows, columns] = torch.tensor([[4, 6], [2, 3]], dtype=dtype, device=device)
        self.assertEqual(reference[rows, columns],
                         torch.tensor([[4, 6], [2, 3]], dtype=dtype, device=device))

        # stride != 1

        # strided is [[1 3 5 7],
        #             [9 11 13 15]]

        reference = torch.arange(0., 24, dtype=dtype, device=device).view(3, 8)
        strided = torch.tensor((), dtype=dtype, device=device)
        strided.set_(reference.storage(), 1, size=torch.Size([2, 4]),
                     stride=[8, 2])

        self.assertEqual(strided[ri([0, 1]), ri([0])],
                         torch.tensor([1, 9], dtype=dtype, device=device))
        self.assertEqual(strided[ri([0, 1]), ri([1])],
                         torch.tensor([3, 11], dtype=dtype, device=device))
        self.assertEqual(strided[ri([0]), ri([0])],
                         torch.tensor([1], dtype=dtype, device=device))
        self.assertEqual(strided[ri([1]), ri([3])],
                         torch.tensor([15], dtype=dtype, device=device))
        self.assertEqual(strided[[ri([0, 0]), ri([0, 3])]],
                         torch.tensor([1, 7], dtype=dtype, device=device))
        self.assertEqual(strided[[ri([1]), ri([0, 1, 1, 0, 3])]],
                         torch.tensor([9, 11, 11, 9, 15], dtype=dtype, device=device))
        self.assertEqual(strided[[ri([0, 0, 1, 1]), ri([0, 1, 0, 0])]],
                         torch.tensor([1, 3, 9, 9], dtype=dtype, device=device))

        rows = ri([[0, 0],
                   [1, 1]])
        columns = [0],
        self.assertEqual(strided[rows, columns],
                         torch.tensor([[1, 1], [9, 9]], dtype=dtype, device=device))

        rows = ri([[0, 1],
                   [1, 0]])
        columns = ri([1, 2])
        self.assertEqual(strided[rows, columns],
                         torch.tensor([[3, 13], [11, 5]], dtype=dtype, device=device))
        rows = ri([[0, 0],
                   [1, 1]])
        columns = ri([[0, 1],
                      [1, 2]])
        self.assertEqual(strided[rows, columns],
                         torch.tensor([[1, 3], [11, 13]], dtype=dtype, device=device))

        # setting values

        # strided is [[10, 11],
        #             [17, 18]]

        reference = torch.arange(0., 24, dtype=dtype, device=device).view(3, 8)
        strided = torch.tensor((), dtype=dtype, device=device)
        strided.set_(reference.storage(), 10, size=torch.Size([2, 2]),
                     stride=[7, 1])
        self.assertEqual(strided[ri([0]), ri([1])],
                         torch.tensor([11], dtype=dtype, device=device))
        strided[ri([0]), ri([1])] = -1
        self.assertEqual(strided[ri([0]), ri([1])],
                         torch.tensor([-1], dtype=dtype, device=device))

        reference = torch.arange(0., 24, dtype=dtype, device=device).view(3, 8)
        strided = torch.tensor((), dtype=dtype, device=device)
        strided.set_(reference.storage(), 10, size=torch.Size([2, 2]),
                     stride=[7, 1])
        self.assertEqual(strided[ri([0, 1]), ri([1, 0])],
                         torch.tensor([11, 17], dtype=dtype, device=device))
        strided[ri([0, 1]), ri([1, 0])] = torch.tensor([-1, 2], dtype=dtype, device=device)
        self.assertEqual(strided[ri([0, 1]), ri([1, 0])],
                         torch.tensor([-1, 2], dtype=dtype, device=device))

        reference = torch.arange(0., 24, dtype=dtype, device=device).view(3, 8)
        strided = torch.tensor((), dtype=dtype, device=device)
        strided.set_(reference.storage(), 10, size=torch.Size([2, 2]),
                     stride=[7, 1])

        rows = ri([[0],
                   [1]])
        columns = ri([[0, 1],
                      [0, 1]])
        self.assertEqual(strided[rows, columns],
                         torch.tensor([[10, 11], [17, 18]], dtype=dtype, device=device))
        strided[rows, columns] = torch.tensor([[4, 6], [2, 3]], dtype=dtype, device=device)
        self.assertEqual(strided[rows, columns],
                         torch.tensor([[4, 6], [2, 3]], dtype=dtype, device=device))

        # Tests using less than the number of dims, and ellipsis

        # reference is 1 2
        #              3 4
        #              5 6
        reference = consec((3, 2))
        self.assertEqual(reference[ri([0, 2]), ],
                         torch.tensor([[1, 2], [5, 6]], dtype=dtype, device=device))
        self.assertEqual(reference[ri([1]), ...],
                         torch.tensor([[3, 4]], dtype=dtype, device=device))
        self.assertEqual(reference[..., ri([1])],
                         torch.tensor([[2], [4], [6]], dtype=dtype, device=device))

        # verify too many indices fails
        with self.assertRaises(IndexError):
            reference[ri([1]), ri([0, 2]), ri([3])]

        # test invalid index fails
        reference = torch.empty(10, dtype=dtype, device=device)
        # can't test cuda because it is a device assert
        if not reference.is_cuda:
            for err_idx in (10, -11):
                with self.assertRaisesRegex(IndexError, r'out of'):
                    reference[err_idx]
                with self.assertRaisesRegex(IndexError, r'out of'):
                    reference[torch.LongTensor([err_idx]).to(device)]
                with self.assertRaisesRegex(IndexError, r'out of'):
                    reference[[err_idx]]

        def tensor_indices_to_np(tensor, indices):
            # convert the Torch Tensor to a numpy array
            tensor = tensor.to(device='cpu')
            npt = tensor.numpy()

            # convert indices: tensor indices on any device become Python
            # lists (torch.LongTensor only matches CPU int64 tensors)
            idxs = tuple(i.tolist() if isinstance(i, torch.Tensor) else
                         i for i in indices)

            return npt, idxs

        def get_numpy(tensor, indices):
            npt, idxs = tensor_indices_to_np(tensor, indices)

            # index and return as a Torch Tensor
            return torch.tensor(npt[idxs], dtype=dtype, device=device)

        def set_numpy(tensor, indices, value):
            if not isinstance(value, int):
                if self.device_type != 'cpu':
                    value = value.cpu()
                value = value.numpy()

            npt, idxs = tensor_indices_to_np(tensor, indices)
            npt[idxs] = value
            return npt

        def assert_get_eq(tensor, indexer):
            self.assertEqual(tensor[indexer], get_numpy(tensor, indexer))

        def assert_set_eq(tensor, indexer, val):
            pyt = tensor.clone()
            numt = tensor.clone()
            pyt[indexer] = val
            numt = torch.tensor(set_numpy(numt, indexer, val), dtype=dtype, device=device)
            self.assertEqual(pyt, numt)

        def assert_backward_eq(tensor, indexer):
            # Gradients of ``tensor[indexer]`` computed on the CPU must match
            # the ones computed on ``device`` for the same upstream gradient.
            # ``tensor`` already lives on ``device``, so move it to the CPU
            # explicitly; otherwise both sides run on ``device`` and nothing
            # is actually cross-checked.
            cpu = tensor.float().cpu().clone().detach().requires_grad_(True)
            outcpu = cpu[indexer]
            gOcpu = torch.rand_like(outcpu)
            outcpu.backward(gOcpu)
            dev = cpu.to(device).detach().requires_grad_(True)
            outdev = dev[indexer]
            outdev.backward(gOcpu.to(device))
            self.assertEqual(cpu.grad, dev.grad)

        def get_set_tensor(indexed, indexer):
            set_size = indexed[indexer].size()
            set_count = indexed[indexer].numel()
            set_tensor = torch.randperm(set_count).view(set_size).double().to(device)
            return set_tensor

        # Tensor is  0  1  2  3  4
        #            5  6  7  8  9
        #           10 11 12 13 14
        #           15 16 17 18 19
        reference = torch.arange(0., 20, dtype=dtype, device=device).view(4, 5)

        indices_to_test = [
            # grab the second, fourth columns
            [slice(None), [1, 3]],

            # first, third rows,
            [[0, 2], slice(None)],

            # weird shape
            [slice(None), [[0, 1],
                           [2, 3]]],
            # negatives
            [[-1], [0]],
            [[0, 2], [-1]],
            [slice(None), [-1]],
        ]

        # only test dupes on gets
        get_indices_to_test = indices_to_test + [[slice(None), [0, 1, 1, 2, 2]]]

        for indexer in get_indices_to_test:
            assert_get_eq(reference, indexer)
            if self.device_type != 'cpu':
                assert_backward_eq(reference, indexer)

        for indexer in indices_to_test:
            assert_set_eq(reference, indexer, 44)
            assert_set_eq(reference,
                          indexer,
                          get_set_tensor(reference, indexer))

        reference = torch.arange(0., 160, dtype=dtype, device=device).view(4, 8, 5)

        indices_to_test = [
            [slice(None), slice(None), [0, 3, 4]],
            [slice(None), [2, 4, 5, 7], slice(None)],
            [[2, 3], slice(None), slice(None)],
            [slice(None), [0, 2, 3], [1, 3, 4]],
            [slice(None), [0], [1, 2, 4]],
            [slice(None), [0, 1, 3], [4]],
            [slice(None), [[0, 1], [1, 0]], [[2, 3]]],
            [slice(None), [[0, 1], [2, 3]], [[0]]],
            [slice(None), [[5, 6]], [[0, 3], [4, 4]]],
            [[0, 2, 3], [1, 3, 4], slice(None)],
            [[0], [1, 2, 4], slice(None)],
            [[0, 1, 3], [4], slice(None)],
            [[[0, 1], [1, 0]], [[2, 1], [3, 5]], slice(None)],
            [[[0, 1], [1, 0]], [[2, 3]], slice(None)],
            [[[0, 1], [2, 3]], [[0]], slice(None)],
            [[[2, 1]], [[0, 3], [4, 4]], slice(None)],
            [[[2]], [[0, 3], [4, 1]], slice(None)],
            # non-contiguous indexing subspace
            [[0, 2, 3], slice(None), [1, 3, 4]],
            # [...]
            # less dim, ellipsis
            [[0, 2], ],
            [[0, 2], slice(None)],
            [[0, 2], Ellipsis],
            [[0, 2], slice(None), Ellipsis],
            [[0, 2], Ellipsis, slice(None)],
            [[0, 2], [1, 3]],
            [[0, 2], [1, 3], Ellipsis],
            [Ellipsis, [1, 3], [2, 3]],
            [Ellipsis, [2, 3, 4]],
            [Ellipsis, slice(None), [2, 3, 4]],
            [slice(None), Ellipsis, [2, 3, 4]],

            # ellipsis counts for nothing
            [Ellipsis, slice(None), slice(None), [0, 3, 4]],
            [slice(None), Ellipsis, slice(None), [0, 3, 4]],
            [slice(None), slice(None), Ellipsis, [0, 3, 4]],
            [slice(None), slice(None), [0, 3, 4], Ellipsis],
            [Ellipsis, [[0, 1], [1, 0]], [[2, 1], [3, 5]], slice(None)],
            [[[0, 1], [1, 0]], [[2, 1], [3, 5]], Ellipsis, slice(None)],
            [[[0, 1], [1, 0]], [[2, 1], [3, 5]], slice(None), Ellipsis],
        ]

        for indexer in indices_to_test:
            assert_get_eq(reference, indexer)
            assert_set_eq(reference, indexer, 212)
            assert_set_eq(reference, indexer, get_set_tensor(reference, indexer))
            # gate on the device under test, like the other backward checks
            if self.device_type != 'cpu':
                assert_backward_eq(reference, indexer)

        reference = torch.arange(0., 1296, dtype=dtype, device=device).view(3, 9, 8, 6)

        indices_to_test = [
            [slice(None), slice(None), slice(None), [0, 3, 4]],
            [slice(None), slice(None), [2, 4, 5, 7], slice(None)],
            [slice(None), [2, 3], slice(None), slice(None)],
            [[1, 2], slice(None), slice(None), slice(None)],
            [slice(None), slice(None), [0, 2, 3], [1, 3, 4]],
            [slice(None), slice(None), [0], [1, 2, 4]],
            [slice(None), slice(None), [0, 1, 3], [4]],
            [slice(None), slice(None), [[0, 1], [1, 0]], [[2, 3]]],
            [slice(None), slice(None), [[0, 1], [2, 3]], [[0]]],
            [slice(None), slice(None), [[5, 6]], [[0, 3], [4, 4]]],
            [slice(None), [0, 2, 3], [1, 3, 4], slice(None)],
            [slice(None), [0], [1, 2, 4], slice(None)],
            [slice(None), [0, 1, 3], [4], slice(None)],
            [slice(None), [[0, 1], [3, 4]], [[2, 3], [0, 1]], slice(None)],
            [slice(None), [[0, 1], [3, 4]], [[2, 3]], slice(None)],
            [slice(None), [[0, 1], [3, 2]], [[0]], slice(None)],
            [slice(None), [[2, 1]], [[0, 3], [6, 4]], slice(None)],
            [slice(None), [[2]], [[0, 3], [4, 2]], slice(None)],
            [[0, 1, 2], [1, 3, 4], slice(None), slice(None)],
            [[0], [1, 2, 4], slice(None), slice(None)],
            [[0, 1, 2], [4], slice(None), slice(None)],
            [[[0, 1], [0, 2]], [[2, 4], [1, 5]], slice(None), slice(None)],
            [[[0, 1], [1, 2]], [[2, 0]], slice(None), slice(None)],
            [[[2, 2]], [[0, 3], [4, 5]], slice(None), slice(None)],
            [[[2]], [[0, 3], [4, 5]], slice(None), slice(None)],
            [slice(None), [3, 4, 6], [0, 2, 3], [1, 3, 4]],
            [slice(None), [2, 3, 4], [1, 3, 4], [4]],
            [slice(None), [0, 1, 3], [4], [1, 3, 4]],
            [slice(None), [6], [0, 2, 3], [1, 3, 4]],
            [slice(None), [2, 3, 5], [3], [4]],
            [slice(None), [0], [4], [1, 3, 4]],
            [slice(None), [6], [0, 2, 3], [1]],
            [slice(None), [[0, 3], [3, 6]], [[0, 1], [1, 3]], [[5, 3], [1, 2]]],
            [[2, 2, 1], [0, 2, 3], [1, 3, 4], slice(None)],
            [[2, 0, 1], [1, 2, 3], [4], slice(None)],
            [[0, 1, 2], [4], [1, 3, 4], slice(None)],
            [[0], [0, 2, 3], [1, 3, 4], slice(None)],
            [[0, 2, 1], [3], [4], slice(None)],
            [[0], [4], [1, 3, 4], slice(None)],
            [[1], [0, 2, 3], [1], slice(None)],
            [[[1, 2], [1, 2]], [[0, 1], [2, 3]], [[2, 3], [3, 5]], slice(None)],

            # less dim, ellipsis
            [Ellipsis, [0, 3, 4]],
            [Ellipsis, slice(None), [0, 3, 4]],
            [Ellipsis, slice(None), slice(None), [0, 3, 4]],
            [slice(None), Ellipsis, [0, 3, 4]],
            [slice(None), slice(None), Ellipsis, [0, 3, 4]],
            [slice(None), [0, 2, 3], [1, 3, 4]],
            [slice(None), [0, 2, 3], [1, 3, 4], Ellipsis],
            [Ellipsis, [0, 2, 3], [1, 3, 4], slice(None)],
            [[0], [1, 2, 4]],
            [[0], [1, 2, 4], slice(None)],
            [[0], [1, 2, 4], Ellipsis],
            [[0], [1, 2, 4], Ellipsis, slice(None)],
            [[1], ],
            [[0, 2, 1], [3], [4]],
            [[0, 2, 1], [3], [4], slice(None)],
            [[0, 2, 1], [3], [4], Ellipsis],
            [Ellipsis, [0, 2, 1], [3], [4]],
        ]

        for indexer in indices_to_test:
            assert_get_eq(reference, indexer)
            assert_set_eq(reference, indexer, 1333)
            assert_set_eq(reference, indexer, get_set_tensor(reference, indexer))
        indices_to_test += [
            [slice(None), slice(None), [[0, 1], [1, 0]], [[2, 3], [3, 0]]],
            [slice(None), slice(None), [[2]], [[0, 3], [4, 4]]],
        ]
        for indexer in indices_to_test:
            assert_get_eq(reference, indexer)
            assert_set_eq(reference, indexer, 1333)
            if self.device_type != 'cpu':
                assert_backward_eq(reference, indexer)
647

648
    def test_advancedindex_big(self, device):
        """Advanced-index a large 1-D ``arange`` with a list of positions.

        Since the source is an ``arange``, each gathered element equals its
        own position.
        """
        positions = [0, 123, 44488, 68807, 123343]
        big = torch.arange(0, 123344, dtype=torch.int, device=device)
        expected = torch.tensor(positions, dtype=torch.int)
        self.assertEqual(big[positions, ], expected)
653

654
    def test_set_item_to_scalar_tensor(self, device):
        """Gradient flows back through a 0-d tensor written into a whole column.

        The scalar is copied into every one of the ``rows`` entries of column 0,
        so after ``sum().backward()`` its gradient is ``rows`` (expressed below
        as ``rows * fill`` with ``fill == 1.0``).
        """
        rows = random.randint(1, 10)
        cols = random.randint(1, 10)
        target = torch.randn([rows, cols], device=device)
        fill = 1.0
        scalar = torch.tensor(fill, requires_grad=True, device=device)
        target[:, 0] = scalar
        target.sum().backward()
        self.assertEqual(scalar.grad, rows * fill)
663

664
    def test_single_int(self, device):
        """A single integer index removes the leading dimension."""
        src = torch.randn(5, 7, 3, device=device)
        row = src[4]
        self.assertEqual(row.shape, (7, 3))
667

668
    def test_multiple_int(self, device):
        """Each integer index drops its dimension; a full slice keeps its own."""
        src = torch.randn(5, 7, 3, device=device)
        first = src[4]
        self.assertEqual(first.shape, (7, 3))
        picked = src[4, :, 1]
        self.assertEqual(picked.shape, (7,))
672

673
    def test_none(self, device):
        """``None`` in an index inserts a size-1 dimension at that position."""
        src = torch.randn(5, 7, 3, device=device)
        cases = [
            ((None,), (1, 5, 7, 3)),
            ((slice(None), None), (5, 1, 7, 3)),
            ((slice(None), None, None), (5, 1, 1, 7, 3)),
            ((Ellipsis, None), (5, 7, 3, 1)),
        ]
        for index, expected_shape in cases:
            self.assertEqual(src[index].shape, expected_shape)
679

680
    def test_step(self, device):
        """Positive slice steps, including a step longer than the tensor."""
        seq = torch.arange(10, device=device)
        self.assertEqual(seq[::1], seq)
        cases = [
            (slice(None, None, 2), [0, 2, 4, 6, 8]),
            (slice(None, None, 3), [0, 3, 6, 9]),
            (slice(None, None, 11), [0]),
            (slice(1, 6, 2), [1, 3, 5]),
        ]
        for sl, expected in cases:
            self.assertEqual(seq[sl].tolist(), expected)
687

688
    def test_step_assignment(self, device):
        """Assigning through a strided slice writes only the selected entries."""
        grid = torch.zeros(4, 4, device=device)
        grid[0, 1::2] = torch.tensor([3., 4.], device=device)
        first_row, other_rows = grid[0], grid[1:]
        self.assertEqual(first_row.tolist(), [0, 3, 0, 4])
        self.assertEqual(other_rows.sum(), 0)
693

694
    def test_bool_indices(self, device):
        """Bool masks select along the leading dim; uint8 masks give the same result.

        The test expects one recorded warning per uint8-mask indexing expression
        (two in total).
        """
        v = torch.randn(5, 7, 3, device=device)
        boolIndices = torch.tensor([True, False, True, True, False], dtype=torch.bool, device=device)
        self.assertEqual(v[boolIndices].shape, (3, 7, 3))
        self.assertEqual(v[boolIndices], torch.stack([v[0], v[2], v[3]]))

        v = torch.tensor([True, False, True], dtype=torch.bool, device=device)
        boolIndices = torch.tensor([True, False, False], dtype=torch.bool, device=device)
        uint8Indices = torch.tensor([1, 0, 0], dtype=torch.uint8, device=device)
        with warnings.catch_warnings(record=True) as w:
            # record=True keeps whatever filters are active; force "always" so an
            # outer "ignore"/"error"/once-per-location filter cannot change the count.
            warnings.simplefilter("always")
            self.assertEqual(v[boolIndices].shape, v[uint8Indices].shape)
            self.assertEqual(v[boolIndices], v[uint8Indices])
            self.assertEqual(v[boolIndices], tensor([True], dtype=torch.bool, device=device))
            self.assertEqual(len(w), 2)
708

709
    def test_bool_indices_accumulate(self, device):
        """Accumulating through an all-False bool mask leaves the tensor unchanged."""
        empty_mask = torch.zeros(size=(10, ), dtype=torch.bool, device=device)
        target = torch.ones(size=(10, 10), device=device)
        selected = target[empty_mask]
        target.index_put_((empty_mask, ), selected, accumulate=True)
        expected = torch.ones(size=(10, 10), device=device)
        self.assertEqual(target, expected)
714

715
    def test_multiple_bool_indices(self, device):
        """Two bool masks separated by a slice broadcast together.

        Each mask selects three positions; the broadcast result dimension is
        moved to the front, so the result has shape (3, 7).
        """
        src = torch.randn(5, 7, 3, device=device)
        row_mask = torch.tensor([1, 0, 1, 1, 0], dtype=torch.bool, device=device)
        col_mask = torch.tensor([1, 1, 1], dtype=torch.bool, device=device)
        picked = src[row_mask, :, col_mask]
        self.assertEqual(picked.shape, (3, 7))
721

722
    def test_byte_mask(self, device):
        """A uint8 mask selects rows like a bool mask; an empty selection is empty.

        The test expects one recorded warning per uint8-mask indexing expression
        (two in total).
        """
        v = torch.randn(5, 7, 3, device=device)
        mask = torch.ByteTensor([1, 0, 1, 1, 0]).to(device)
        with warnings.catch_warnings(record=True) as w:
            # record=True keeps whatever filters are active; force "always" so the
            # count does not depend on outer warning configuration.
            warnings.simplefilter("always")
            self.assertEqual(v[mask].shape, (3, 7, 3))
            self.assertEqual(v[mask], torch.stack([v[0], v[2], v[3]]))
            self.assertEqual(len(w), 2)

        v = torch.tensor([1.], device=device)
        self.assertEqual(v[v == 0], torch.tensor([], device=device))
732

733
    def test_byte_mask_accumulate(self, device):
        """Accumulating through an all-zero uint8 mask leaves the tensor unchanged.

        Two warnings are expected, presumably one per uint8-mask use (the read
        and the ``index_put_``).
        """
        byte_mask = torch.zeros(size=(10, ), dtype=torch.uint8, device=device)
        target = torch.ones(size=(10, 10), device=device)
        with warnings.catch_warnings(record=True) as caught:
            warnings.simplefilter("always")
            target.index_put_((byte_mask, ), target[byte_mask], accumulate=True)
            expected = torch.ones(size=(10, 10), device=device)
            self.assertEqual(target, expected)
            self.assertEqual(len(caught), 2)
741

742
    @skipIfTorchDynamo("This test causes SIGKILL when running with dynamo, https://github.com/pytorch/pytorch/issues/88472")
    def test_index_put_accumulate_large_tensor(self, device):
        """Accumulating ``index_put_`` on tensors with >= INT_MAX (2^31 - 1) elements.

        Both cases mix negative and repeated indices, so wrap-around and repeated
        accumulation are exercised past the 32-bit element-count boundary.
        """
        N = (1 << 31) + 5
        dt = torch.int8

        # 1-D: start from ones, then accumulate at both ends of the tensor.
        a = torch.ones(N, dtype=dt, device=device)
        idx = torch.tensor([-2, 0, -2, -1, 0, -1, 1], device=device, dtype=torch.long)
        vals = torch.tensor([6, 5, 6, 6, 5, 7, 11], dtype=dt, device=device)
        a.index_put_((idx, ), vals, accumulate=True)
        for pos, want in ((0, 11), (1, 12), (2, 1), (-3, 1), (-2, 13), (-1, 14)):
            self.assertEqual(a[pos], want)

        # 2-D: the large extent is the second dim; index with (row, col) pairs.
        # Rebinding ``a`` releases the 1-D tensor, as before.
        a = torch.ones((2, N), dtype=dt, device=device)
        rows = torch.tensor([0, -1, 0, 1], device=device, dtype=torch.long)
        cols = torch.tensor([-2, -1, 0, 1], device=device, dtype=torch.long)
        vals = torch.tensor([12, 13, 10, 11], dtype=dt, device=device)
        a.index_put_((rows, cols), vals, accumulate=True)
        for pos, want in (((0, 0), 11), ((0, 1), 1), ((1, 0), 1), ((1, 1), 12)):
            self.assertEqual(a[pos], want)
        untouched_col = torch.ones(2, dtype=torch.int8)
        self.assertEqual(a[:, 2], untouched_col)
        self.assertEqual(a[:, -3], untouched_col)
        for pos, want in (((0, -2), 13), ((1, -2), 1), ((-1, -1), 14), ((0, -1), 1)):
            self.assertEqual(a[pos], want)
776
        self.assertEqual(a[0, -1], 1)
777

778
    @onlyNativeDeviceTypes
    def test_index_put_accumulate_expanded_values(self, device):
        """Accumulating ``index_put_`` with values that must be broadcast.

        Checks the issue with cuda: https://github.com/pytorch/pytorch/issues/39227
        and verifies consistency with the CPU result.
        """
        # copy=True: when ``device`` is the CPU, ``t.to(device)`` returns ``t``
        # itself, so both in-place calls would hit one tensor and the comparison
        # would trivially compare a tensor with itself.
        t = torch.zeros((5, 2))
        t_dev = t.to(device, copy=True)
        indices = [
            torch.tensor([0, 1, 2, 3]),
            torch.tensor([1, ]),
        ]
        indices_dev = [i.to(device) for i in indices]
        values0d = torch.tensor(1.0)
        values1d = torch.tensor([1.0, ])

        out_dev = t_dev.index_put_(indices_dev, values0d.to(device), accumulate=True)
        out_cpu = t.index_put_(indices, values0d, accumulate=True)
        self.assertEqual(out_dev.cpu(), out_cpu)

        out_dev = t_dev.index_put_(indices_dev, values1d.to(device), accumulate=True)
        out_cpu = t.index_put_(indices, values1d, accumulate=True)
        self.assertEqual(out_dev.cpu(), out_cpu)

        t = torch.zeros(4, 3, 2)
        t_dev = t.to(device, copy=True)

        # Index tensors of shapes (1,), (3, 1) and (1, 2) broadcast to (3, 2).
        indices = [
            torch.tensor([0, ]),
            torch.arange(3)[:, None],
            torch.arange(2)[None, :],
        ]
        indices_dev = [i.to(device) for i in indices]
        values1d = torch.tensor([-1.0, -2.0])
        values2d = torch.tensor([[-1.0, -2.0], ])

        out_dev = t_dev.index_put_(indices_dev, values1d.to(device), accumulate=True)
        out_cpu = t.index_put_(indices, values1d, accumulate=True)
        self.assertEqual(out_dev.cpu(), out_cpu)

        out_dev = t_dev.index_put_(indices_dev, values2d.to(device), accumulate=True)
        out_cpu = t.index_put_(indices, values2d, accumulate=True)
        self.assertEqual(out_dev.cpu(), out_cpu)
819

820
    @onlyCUDA
    def test_index_put_accumulate_non_contiguous(self, device):
        """Accumulating into a non-contiguous view matches the CPU result.

        Also checks that the in-place op leaves both views non-contiguous.
        """
        host = torch.zeros((5, 2, 2))
        dev_view = host.to(device)[:, 0, :]
        host_view = host[:, 0, :]
        for view in (dev_view, host_view):
            self.assertFalse(view.is_contiguous())

        idx_host = [torch.tensor([0, 1]), ]
        idx_dev = [ix.to(device) for ix in idx_host]
        src = torch.randn(2, 2)
        res_dev = dev_view.index_put_(idx_dev, src.to(device), accumulate=True)
        res_host = host_view.index_put_(idx_host, src, accumulate=True)
        for view in (dev_view, host_view):
            self.assertFalse(view.is_contiguous())

        self.assertEqual(res_dev.cpu(), res_host)
838

839
    @onlyCUDA
    @skipIfTorchDynamo("Not a suitable test for TorchDynamo")
    def test_index_put_accumulate_with_optional_tensors(self, device):
        """Accumulating ``index_put_`` where the index list contains ``None``.

        Device and CPU results must agree for 0-d and 1-d values.
        """
        # TODO: replace with a better solution.
        # Currently, here using torchscript to put None into indices.
        # on C++ it gives indices as a list of 2 optional tensors: first is null and
        # the second is a valid tensor.
        @torch.jit.script
        def func(x, i, v):
            idx = [None, i]
            x.index_put_(idx, v, accumulate=True)
            return x

        n = 4
        t = torch.arange(n * 2, dtype=torch.float32).reshape(n, 2)
        t_dev = t.to(device)
        indices = torch.tensor([1, 0])
        indices_dev = indices.to(device)
        value0d = torch.tensor(10.0)
        value1d = torch.tensor([1.0, 2.0])

        # Use .to(device) rather than .cuda(): .cuda() targets the *current* CUDA
        # device, which need not be the device this test instance runs on.
        out_cuda = func(t_dev, indices_dev, value0d.to(device))
        out_cpu = func(t, indices, value0d)
        self.assertEqual(out_cuda.cpu(), out_cpu)

        out_cuda = func(t_dev, indices_dev, value1d.to(device))
        out_cpu = func(t, indices, value1d)
        self.assertEqual(out_cuda.cpu(), out_cpu)
867

868
    @onlyNativeDeviceTypes
    def test_index_put_accumulate_duplicate_indices(self, device):
        """Accumulate with many interleaved duplicate indices vs. a Python reference."""
        for n in range(1, 512):
            # generate indices by random walk, this will create indices with
            # lots of duplicates interleaved with each other
            delta = torch.empty(n, dtype=torch.double, device=device).uniform_(-1, 1)
            indices = delta.cumsum(0).long()

            base = torch.randn(indices.abs().max() + 1, device=device)
            values = torch.randn(indices.size(0), device=device)
            output = base.index_put((indices,), values, accumulate=True)

            # Reference: apply each (index, value) update sequentially in Python.
            # The walk can go negative; such indices wrap like Python list indices.
            expected = base.tolist()
            for idx, val in zip(indices.tolist(), values.tolist()):
                expected[idx] += val

            self.assertEqual(output, expected)
887

888
    @onlyNativeDeviceTypes
    def test_index_ind_dtype(self, device):
        """int32 index tensors give the same results as int64 ones (get and put)."""
        x = torch.randn(4, 4, device=device)
        ind_long = torch.randint(4, (4,), dtype=torch.long, device=device)
        ind_int = ind_long.int()
        src = torch.randn(4, device=device)

        # getitem on both dims, on the first dim only, and on the second dim only
        paired_keys = (
            ((ind_long, ind_long), (ind_int, ind_int)),
            ((ind_long, slice(None)), (ind_int, slice(None))),
            ((slice(None), ind_long), (slice(None), ind_int)),
        )
        for key_long, key_int in paired_keys:
            self.assertEqual(x[key_long], x[key_int])

        # no repeating indices for index_put
        span_long = torch.arange(4, dtype=torch.long, device=device)
        span_int = span_long.int()
        for accum in (True, False):
            inp_ref = x.clone()
            inp_res = x.clone()
            torch.index_put_(inp_ref, (span_long, span_long), src, accum)
            torch.index_put_(inp_res, (span_int, span_int), src, accum)
            self.assertEqual(inp_ref, inp_res)
912

913
    @skipXLA
    def test_index_put_accumulate_empty(self, device):
        """An empty index list on a 0-d tensor must raise RuntimeError.

        Regression test for https://github.com/pytorch/pytorch/issues/94667
        """
        scalar = torch.rand([], dtype=torch.float32, device=device)
        with self.assertRaises(RuntimeError):
            update = torch.tensor([1.0], device=device)
            scalar.index_put([], update, True)
919

920
    def test_multiple_byte_mask(self, device):
        """Two uint8 masks separated by a slice broadcast together.

        The broadcast result dimension moves to the front, giving (3, 7); two
        warnings are expected, presumably one per uint8 mask.
        """
        src = torch.randn(5, 7, 3, device=device)
        row_mask = torch.ByteTensor([1, 0, 1, 1, 0]).to(device)
        col_mask = torch.ByteTensor([1, 1, 1]).to(device)
        with warnings.catch_warnings(record=True) as caught:
            warnings.simplefilter("always")
            picked = src[row_mask, :, col_mask]
            self.assertEqual(picked.shape, (3, 7))
            self.assertEqual(len(caught), 2)
929

930
    def test_byte_mask2d(self, device):
        """A 2-D mask over the two leading dims flattens them into one."""
        src = torch.randn(5, 7, 3, device=device)
        keys = torch.randn(5, 7, device=device)
        selected_count = (keys > 0).sum()
        picked = src[keys > 0]
        self.assertEqual(picked.shape, (selected_count, 3))
936

937
    @skipIfTorchDynamo("Not a suitable test for TorchDynamo")
    def test_jit_indexing(self, device):
        """Scripted mask-assignment and slice-assignment give the same result."""
        def mask_assign(x):
            x[x < 50] = 1.0
            return x

        def slice_assign(x):
            x[0:50] = 1.0
            return x

        data = torch.arange(100, device=device, dtype=torch.float)
        expected = torch.tensor(np.concatenate((np.ones(50), np.arange(50, 100))), device=device, dtype=torch.float)
        for fn in (mask_assign, slice_assign):
            scripted = torch.jit.script(fn)
            self.assertEqual(scripted(data.detach().clone()), expected)
955

956
    def test_int_indices(self, device):
        """List-of-int indices on different dims produce the expected shapes."""
        src = torch.randn(5, 7, 3, device=device)
        cases = [
            (([0, 4, 2],), (3, 7, 3)),
            ((slice(None), [0, 4, 2]), (5, 3, 3)),
            ((slice(None), [[0, 1], [4, 3]]), (5, 2, 2, 3)),
        ]
        for index, expected_shape in cases:
            self.assertEqual(src[index].shape, expected_shape)
961

962
    @dtypes(torch.cfloat, torch.cdouble, torch.float, torch.bfloat16, torch.long, torch.bool)
    @dtypesIfCPU(torch.cfloat, torch.cdouble, torch.float, torch.long, torch.bool, torch.bfloat16)
    @dtypesIfCUDA(torch.cfloat, torch.cdouble, torch.half, torch.long, torch.bool, torch.bfloat16)
    def test_index_put_src_datatype(self, device, dtype):
        """Accumulating ``index_put_`` is supported for each listed dtype."""
        shape = (3, 2, 4)
        src = torch.ones(shape, device=device, dtype=dtype)
        vals = torch.ones_like(src)
        row_order = (torch.tensor([0, 2, 1]),)
        out = src.index_put_(row_order, vals, accumulate=True)
        self.assertEqual(out.shape, src.shape)
971

972
    @dtypes(torch.float, torch.bfloat16, torch.long, torch.bool)
    @dtypesIfCPU(torch.float, torch.long, torch.bfloat16, torch.bool)
    @dtypesIfCUDA(torch.half, torch.long, torch.bfloat16, torch.bool)
    def test_index_src_datatype(self, device, dtype):
        """Getitem and non-accumulating setitem with a list index, per dtype."""
        perm = [0, 2, 1]
        src = torch.ones(3, 2, 4, device=device, dtype=dtype)
        # getitem with a permutation of dim 0
        gathered = src[perm, :, :]
        self.assertEqual(gathered.shape, src.shape)
        # setitem (index_put without accumulation) with the same permutation
        src[perm, :, :] = gathered
        self.assertEqual(gathered.shape, src.shape)
983

984
    def test_int_indices2d(self, device):
        """Paired 2-D row/column index tensors pick the corners (NumPy example)."""
        grid = torch.arange(0, 12, device=device).view(4, 3)
        row_idx = torch.tensor([[0, 0], [3, 3]], device=device)
        col_idx = torch.tensor([[0, 2], [0, 2]], device=device)
        corners = grid[row_idx, col_idx]
        self.assertEqual(corners.tolist(), [[0, 2], [9, 11]])
990

991
    def test_int_indices_broadcast(self, device):
        """A column of row indices broadcasts against a row of column indices.

        From the NumPy indexing example.
        """
        grid = torch.arange(0, 12, device=device).view(4, 3)
        row_idx = torch.tensor([0, 3], device=device)
        col_idx = torch.tensor([0, 2], device=device)
        corners = grid[row_idx.unsqueeze(1), col_idx]
        self.assertEqual(corners.tolist(), [[0, 2], [9, 11]])
998

999
    def test_empty_index(self, device):
        """An empty long index or all-False mask selects nothing and writes nothing."""
        base = torch.arange(0, 12, device=device).view(4, 3)
        no_rows = torch.tensor([], dtype=torch.long, device=device)
        self.assertEqual(base[no_rows].numel(), 0)

        # empty assignment should have no effect but not throw an exception
        work = base.clone()
        work[no_rows] = -1
        self.assertEqual(base, work)

        all_false = torch.zeros(4, 3, device=device).bool()
        work[all_false] = -1
        self.assertEqual(base, work)
1012

1013
    def test_empty_ndim_index(self, device):
        """Empty multi-dim int64 indices yield correctly shaped empty results."""
        vec = torch.randn(5, device=device)
        empty_2d_idx = torch.empty(0, 2, dtype=torch.int64, device=device)
        self.assertEqual(torch.empty(0, 2, device=device), vec[empty_2d_idx])

        t4 = torch.randn(2, 3, 4, 5, device=device)
        empty_idx = torch.empty(0, 6, dtype=torch.int64, device=device)
        self.assertEqual(torch.empty(2, 0, 6, 4, 5, device=device), t4[:, empty_idx])

        zero_cols = torch.empty(10, 0, device=device)
        self.assertEqual(zero_cols[[1, 2]].shape, (2, 0))
        self.assertEqual(zero_cols[[], []].shape, (0,))
        with self.assertRaisesRegex(IndexError, 'for dimension with size 0'):
            zero_cols[:, [0, 1]]
1026

1027
    def test_empty_ndim_index_bool(self, device):
        """An empty 2-D uint8 mask cannot index a 1-D tensor."""
        vec = torch.randn(5, device=device)
        with self.assertRaises(IndexError):
            vec[torch.empty(0, 2, dtype=torch.uint8, device=device)]
1030

1031
    def test_empty_slice(self, device):
        """An empty slice keeps NumPy-style strides and reports contiguous."""
        base = torch.randn(2, 3, 4, 5, device=device)
        empty = base[:, :, :, 1][:, 1:1, :]
        self.assertEqual((2, 0, 4), empty.shape)
        # this isn't technically necessary, but matches NumPy stride calculations.
        self.assertEqual((60, 20, 5), empty.stride())
        self.assertTrue(empty.is_contiguous())
1039

1040
    def test_index_getitem_copy_bools_slices(self, device):
        """True-like indices copy; False-like ones give an empty leading dim.

        ``None`` and ``...`` return results that share the original data pointer.
        """
        true_t = torch.tensor(1, dtype=torch.uint8, device=device)
        false_t = torch.tensor(0, dtype=torch.uint8, device=device)

        tensors = [torch.randn(2, 3, device=device), torch.tensor(3., device=device)]

        for a in tensors:
            # Python bools first, then 0-d uint8 tensors
            for yes, no in ((True, False), (true_t, false_t)):
                self.assertNotEqual(a.data_ptr(), a[yes].data_ptr())
                self.assertEqual(torch.empty(0, *a.shape), a[no])
            for idx in (None, Ellipsis):
                self.assertEqual(a.data_ptr(), a[idx].data_ptr())
1053

1054
    def test_index_setitem_bools_slices(self, device):
        """Setitem via bool/None/Ellipsis indices accepts values with leading 1s.

        True-like indices write the value, False-like indices write nothing, and
        ``[:]`` on a 0-d tensor raises IndexError.
        """
        true_t = torch.tensor(1, dtype=torch.uint8, device=device)
        false_t = torch.tensor(0, dtype=torch.uint8, device=device)

        tensors = [torch.randn(2, 3, device=device), torch.tensor(3, device=device)]

        for a in tensors:
            # prefix with a 1,1, to ensure we are compatible with numpy which cuts off prefix 1s
            # (some of these ops already prefix a 1 to the size)
            neg_ones = torch.ones_like(a) * -1
            padded = neg_ones[None, None]
            # (index, multiplier written or None for the scalar 5, expected multiplier)
            steps = (
                (True, 1, 1),
                (False, None, 1),
                (true_t, 2, 2),
                (false_t, None, 2),
                (None, 3, 3),
                (Ellipsis, 4, 4),
            )
            for index, write_scale, expect_scale in steps:
                if write_scale is None:
                    a[index] = 5
                else:
                    a[index] = padded * write_scale
                self.assertEqual(a, neg_ones * expect_scale)
            if a.dim() == 0:
                with self.assertRaises(IndexError):
                    a[:] = padded * 5
1080

1081
    def test_index_scalar_with_bool_mask(self, device):
        """On 0-d tensors, a 0-d uint8 mask and a 0-d bool mask index alike."""
        uint_mask = torch.tensor(True, dtype=torch.uint8, device=device)
        bool_mask = torch.tensor(True, dtype=torch.bool, device=device)
        scalars = (
            torch.tensor(1, device=device),
            torch.tensor(True, dtype=torch.bool, device=device),
        )
        for scalar in scalars:
            self.assertEqual(scalar[uint_mask], scalar[bool_mask])
            self.assertEqual(scalar[uint_mask].dtype, scalar[bool_mask].dtype)
1091

1092
    def test_setitem_expansion_error(self, device):
        """A value whose extra leading dims are not all 1 cannot be assigned via True."""
        true_t = torch.tensor(True, device=device)
        a = torch.randn(2, 3, device=device)
        # check prefix with  non-1s doesn't work
        a_expanded = a.expand(torch.Size([5, 1]) + a.size())
        # NumPy: ValueError
        for index in (True, true_t):
            with self.assertRaises(RuntimeError):
                a[index] = a_expanded
1102

1103
    def test_getitem_scalars(self, device):
        """0-d integer tensors index like Python ints and return views, not copies."""
        zero = torch.tensor(0, dtype=torch.int64, device=device)
        one = torch.tensor(1, dtype=torch.int64, device=device)

        # non-scalar indexed with scalars
        a = torch.randn(2, 3, device=device)
        self.assertEqual(a[0], a[zero])
        self.assertEqual(a[0][1], a[zero][one])
        self.assertEqual(a[0, 1], a[zero, one])
        self.assertEqual(a[0, one], a[zero, 1])

        # indexing by a scalar should slice (not copy)
        self.assertEqual(a[0, 1].data_ptr(), a[zero, one].data_ptr())
        for narrow_one in (one.int(), one.short()):
            self.assertEqual(a[1].data_ptr(), a[narrow_one].data_ptr())

        # scalar indexed with scalar
        r = torch.randn((), device=device)
        for bad in (slice(None), zero):
            with self.assertRaises(IndexError):
                r[bad]
        self.assertEqual(r, r[...])
1126

1127
    def test_setitem_scalars(self, device):
        """Setitem with a 0-d int64 tensor index matches setitem with an int.

        The index tensor is created without a ``device`` argument.
        """
        zero = torch.tensor(0, dtype=torch.int64)

        # non-scalar indexed with scalars
        a = torch.randn(2, 3, device=device)
        by_int = a.clone()
        by_tensor = a.clone()
        b = torch.randn(3, device=device)

        by_int[0] = b
        by_tensor[zero] = b
        self.assertEqual(by_int, by_tensor)
        a[1, zero] = 7.7
        self.assertEqual(7.7, a[1, 0])

        # scalar indexed with scalars
        r = torch.randn((), device=device)
        for bad in (slice(None), zero):
            with self.assertRaises(IndexError):
                r[bad] = 8.8
        r[...] = 9.9
        self.assertEqual(9.9, r)
1150

1151
    def test_basic_advanced_combined(self, device):
        """Slice + list index: reading returns a copy, setitem writes through.

        From the NumPy indexing example.
        """
        x = torch.arange(0, 12, device=device).view(4, 3)
        rows = slice(1, 2)
        cols = [1, 2]
        self.assertEqual(x[rows, slice(1, 3)], x[rows, cols])
        self.assertEqual(x[rows, slice(1, 3)].tolist(), [[4, 5]])

        # Check that it is a copy
        snapshot = x.clone()
        x[rows, cols].zero_()
        self.assertEqual(x, snapshot)

        # But assignment should modify the original
        snapshot = x.clone()
        x[rows, cols] = 0
        self.assertNotEqual(x, snapshot)
1166

1167
    def test_int_assignment(self, device):
        """Assigning a scalar or a row tensor to an integer-indexed row."""
        cases = (
            (5, [[0, 1], [5, 5]]),
            (torch.arange(5, 7, device=device), [[0, 1], [5, 6]]),
        )
        for value, expected in cases:
            grid = torch.arange(0, 4, device=device).view(2, 2)
            grid[1] = value
            self.assertEqual(grid.tolist(), expected)
1175

1176
    def test_byte_tensor_assignment(self, device):
        """Assigning through a uint8 row mask writes only the selected rows.

        The test expects exactly one recorded warning from the uint8-mask setitem.
        """
        x = torch.arange(0., 16, device=device).view(4, 4)
        b = torch.ByteTensor([True, False, True, False]).to(device)
        value = torch.tensor([3., 4., 5., 6.], device=device)

        with warnings.catch_warnings(record=True) as w:
            # record=True keeps whatever filters are active; force "always" so the
            # count does not depend on outer warning configuration.
            warnings.simplefilter("always")
            x[b] = value
            self.assertEqual(len(w), 1)

        self.assertEqual(x[0], value)
        self.assertEqual(x[1], torch.arange(4., 8, device=device))
        self.assertEqual(x[2], value)
        self.assertEqual(x[3], torch.arange(12., 16, device=device))
1189

1190
    def test_variable_slicing(self, device):
        """0-d int tensors unpacked from a tensor work as slice bounds."""
        x = torch.arange(0, 16, device=device).view(4, 4)
        start, stop = torch.IntTensor([0, 1]).to(device)
        self.assertEqual(x[start:stop], x[0:1])
1195

1196
    def test_ellipsis_tensor(self, device):
        """A tensor index combined with an Ellipsis on either side."""
        x = torch.arange(0, 9, device=device).view(3, 3)
        idx = torch.tensor([0, 2], device=device)
        picked_cols = x[..., idx].tolist()
        self.assertEqual(picked_cols, [[0, 2], [3, 5], [6, 8]])
        picked_rows = x[idx, ...].tolist()
        self.assertEqual(picked_rows, [[0, 1, 2], [6, 7, 8]])
1204

1205
    def test_unravel_index_errors(self, device):
        """``torch.unravel_index`` rejects bad index dtypes and malformed shapes."""
        cases = [
            (TypeError, r"expected 'indices' to be integer",
             lambda: torch.unravel_index(torch.tensor(0.5, device=device), (2, 2))),
            (TypeError, r"expected 'indices' to be integer",
             lambda: torch.unravel_index(torch.tensor([], device=device), (10, 3, 5))),
            (TypeError, r"expected 'shape' to be int or sequence",
             lambda: torch.unravel_index(torch.tensor([1], device=device, dtype=torch.int64),
                                         torch.tensor([1, 2, 3]))),
            (TypeError, r"expected 'shape' sequence to only contain ints",
             lambda: torch.unravel_index(torch.tensor([1], device=device, dtype=torch.int64),
                                         (1, 2, 2.0))),
            (ValueError, r"'shape' cannot have negative values, but got \(2, -3\)",
             lambda: torch.unravel_index(torch.tensor(0, device=device), (2, -3))),
        ]
        for exc_type, pattern, call in cases:
            with self.assertRaisesRegex(exc_type, pattern):
                call()
1230

1231
    def test_invalid_index(self, device):
        """String slice bounds are rejected with a TypeError."""
        x = torch.arange(0, 16, device=device).view(4, 4)
        with self.assertRaisesRegex(TypeError, 'slice indices'):
            x["0":"1"]
1234

1235
    def test_out_of_bound_index(self, device):
        """Out-of-range integer indices report the offending dimension and size."""
        x = torch.arange(0, 100, device=device).view(2, 5, 10)
        cases = [
            ('index 5 is out of bounds for dimension 1 with size 5', (0, 5)),
            ('index 4 is out of bounds for dimension 0 with size 2', (4, 5)),
            ('index 15 is out of bounds for dimension 2 with size 10', (0, 1, 15)),
            ('index 12 is out of bounds for dimension 2 with size 10', (slice(None), slice(None), 12)),
        ]
        for message, index in cases:
            with self.assertRaisesRegex(IndexError, message):
                x[index]
1243

1244
    def test_zero_dim_index(self, device):
        """Integer-indexing a 0-d tensor raises IndexError matching 'invalid index'."""
        x = torch.tensor(10, device=device)
        self.assertEqual(x, x.item())

        # The indexing expression itself is what raises; no debug print needed.
        def runner():
            return x[0]

        self.assertRaisesRegex(IndexError, 'invalid index', runner)
1253

1254
    @onlyCUDA
    def test_invalid_device(self, device):
        """``index_put_`` with values on "cpu" into a device tensor raises RuntimeError."""
        idx = torch.tensor([0, 1])
        b = torch.zeros(5, device=device)
        c = torch.tensor([1., 2.], device="cpu")

        for accumulate in (True, False):
            with self.assertRaises(RuntimeError):
                torch.index_put_(b, (idx,), c, accumulate=accumulate)
1262

1263
    @onlyCUDA
    def test_cpu_indices(self, device):
        """An index tensor created without a device works for setitem and getitem."""
        idx = torch.tensor([0, 1])
        zeros_src = torch.zeros(2, device=device)
        target = torch.ones(10, device=device)
        target[idx] = zeros_src  # index_put_
        expected = torch.ones(10, device=device)
        expected[:2] = 0
        self.assertEqual(target, expected, atol=0, rtol=0)
        gathered = target[idx]  # index
        self.assertEqual(gathered, torch.zeros(2, device=device), atol=0, rtol=0)
1274

1275
    @dtypes(torch.long, torch.float32)
    def test_take_along_dim(self, device, dtype):
        """Compare ``torch.take_along_dim`` against ``np.take_along_axis``.

        Covers every dimension plus ``dim=None`` (flattened input) for contiguous
        and non-contiguous inputs, then broadcasting and empty indices.
        """
        def _test_against_numpy(t, indices, dim):
            actual = torch.take_along_dim(t, indices, dim=dim)
            t_np = t.cpu().numpy()
            indices_np = indices.cpu().numpy()
            expected = np.take_along_axis(t_np, indices_np, axis=dim)
            self.assertEqual(actual, expected, atol=0, rtol=0)

        for shape in [(3, 2), (2, 3, 5), (2, 4, 0), (2, 3, 1, 4)]:
            for noncontiguous in [True, False]:
                t = make_tensor(shape, device=device, dtype=dtype, noncontiguous=noncontiguous)
                for dim in list(range(t.ndim)) + [None]:
                    if dim is None:
                        indices = torch.argsort(t.view(-1))
                    else:
                        indices = torch.argsort(t, dim=dim)

                    # Must run inside the ``dim`` loop; otherwise only the last
                    # value of ``dim`` (None) is ever compared.
                    _test_against_numpy(t, indices, dim)

        # test broadcasting
        t = torch.ones((3, 4, 1), device=device)
        indices = torch.ones((1, 2, 5), dtype=torch.long, device=device)

        _test_against_numpy(t, indices, 1)

        # test empty indices
        t = torch.ones((3, 4, 5), device=device)
        indices = torch.ones((3, 0, 5), dtype=torch.long, device=device)

        _test_against_numpy(t, indices, 1)
1306

1307
    @dtypes(torch.long, torch.float)
    def test_take_along_dim_invalid(self, device, dtype):
        """take_along_dim rejects mismatched ranks, non-long indices and out-of-range dims."""
        t = make_tensor((2, 3, 1, 4), device=device, dtype=dtype)
        indices = torch.argsort(t, dim=0)

        # `indices[0]` has one fewer dimension than `t`.
        with self.assertRaisesRegex(RuntimeError,
                                    "input and indices should have the same number of dimensions"):
            torch.take_along_dim(t, indices[0], dim=0)

        # Only int64 indices are accepted.
        for bad_dtype in (torch.bool, torch.float, torch.int32):
            with self.assertRaisesRegex(RuntimeError, r"dtype of indices should be Long"):
                torch.take_along_dim(t, indices.to(bad_dtype), dim=0)

        # `t` is 4-D, so both -7 and 7 fall outside the valid range [-4, 3].
        for bad_dim in (-7, 7):
            with self.assertRaisesRegex(IndexError, "Dimension out of range"):
                torch.take_along_dim(t, indices, dim=bad_dim)

    @onlyCUDA
    @dtypes(torch.float)
    def test_gather_take_along_dim_cross_device(self, device, dtype):
        """gather and take_along_dim must refuse input and indices on different devices."""
        t = make_tensor((2, 3, 1, 4), device=device, dtype=dtype)
        indices = torch.argsort(t, dim=0)
        take_along_dim_msg = r"Expected tensor to have .* but got tensor with .* torch.take_along_dim()"

        # First: device input with CPU indices. Then: CPU input with device indices.
        for src, idx in ((t, indices.cpu()), (t.cpu(), indices)):
            with self.assertRaisesRegex(RuntimeError, "Expected all tensors to be on the same device"):
                torch.gather(src, 0, idx)

            with self.assertRaisesRegex(RuntimeError, take_along_dim_msg):
                torch.take_along_dim(src, idx, dim=0)

    @onlyCUDA
    def test_cuda_broadcast_index_use_deterministic_algorithms(self, device):
        """Deterministic index_put_ on `device` must match CPU for broadcast indices."""
        with DeterministicGuard(True):
            idx1 = torch.tensor([0])
            idx2 = torch.tensor([2, 6])
            idx3 = torch.tensor([1, 5, 7])
            full = slice(None)

            def check(shape, assignments):
                # Apply the same (index, value) writes, in order, to a random
                # CPU tensor and to its copy on `device`; require exact equality.
                cpu_tensor = torch.rand(*shape).cpu()
                dev_tensor = cpu_tensor.to(device=device)
                for target in (cpu_tensor, dev_tensor):
                    for index, value in assignments:
                        target[index] = value
                self.assertEqual(cpu_tensor, dev_tensor.cpu(), atol=0, rtol=0)

            check((13, 11, 12, 13, 12), [
                (idx1, 1.0),
                ((idx1, full, idx2, idx2, full), 2.0),
                ((full, idx1, idx3, full, idx3), 3.0),
            ])
            check((10, 11), [
                (idx3, 1.0),
                ((idx2, full), 2.0),
                ((full, idx2), 3.0),
                ((full, idx1), 4.0),
            ])
            # Plain Python list index and plain integer index.
            check((10, 10), [([8], 1.0)])
            check((10,), [(6, 1.0)])

    def test_index_limits(self, device):
        # Regression test for https://github.com/pytorch/pytorch/issues/115415
        # Indexing an empty tensor with the extreme int64 values must raise
        # IndexError.
        t = torch.tensor([], device=device)
        int64_info = torch.iinfo(torch.int64)
        for extreme in (int64_info.min, int64_info.max):
            with self.assertRaises(IndexError):
                t[extreme]


# The tests below are from NumPy test_indexing.py with some modifications to
# make them compatible with PyTorch. It's licensed under the BSD license below:
1411
#
1412
# Copyright (c) 2005-2017, NumPy Developers.
1413
# All rights reserved.
1414
#
1415
# Redistribution and use in source and binary forms, with or without
1416
# modification, are permitted provided that the following conditions are
1417
# met:
1418
#
1419
#     * Redistributions of source code must retain the above copyright
1420
#        notice, this list of conditions and the following disclaimer.
1421
#
1422
#     * Redistributions in binary form must reproduce the above
1423
#        copyright notice, this list of conditions and the following
1424
#        disclaimer in the documentation and/or other materials provided
1425
#        with the distribution.
1426
#
1427
#     * Neither the name of the NumPy Developers nor the names of any
1428
#        contributors may be used to endorse or promote products derived
1429
#        from this software without specific prior written permission.
1430
#
1431
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
1432
# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
1433
# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
1434
# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
1435
# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
1436
# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
1437
# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
1438
# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
1439
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
1440
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
1441
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
1442

1443
class NumpyTests(TestCase):
    """Indexing tests ported from NumPy's test_indexing.py.

    Every test takes a ``device`` argument; the class is instantiated per
    device type by ``instantiate_device_type_tests`` at the bottom of this file.
    """

    def test_index_no_floats(self, device):
        # Python floats are rejected as indices in every position, alone or
        # mixed with ints and slices.
        a = torch.tensor([[[5.]]], device=device)
        full = slice(None)

        for f in (0.0, -1.4):
            for index in (f, (0, f), (f, 0), (f, full), (full, f), (full, f, full),
                          (f, full, full), (0, 0, f), (f, 0, 0), (0, f, 0)):
                with self.assertRaises(IndexError):
                    a[index]
        # self.assertRaises(IndexError, lambda: a[0.0:, 0.0])
        # self.assertRaises(IndexError, lambda: a[0.0:, 0.0,:])

    def test_none_index(self, device):
        # `None` index adds newaxis
        a = tensor([1, 2, 3], device=device)
        self.assertEqual(a[None].dim(), a.dim() + 1)

    def test_empty_tuple_index(self, device):
        # Empty tuple index creates a view
        a = tensor([1, 2, 3], device=device)
        self.assertEqual(a[()], a)
        self.assertEqual(a[()].data_ptr(), a.data_ptr())

    def test_empty_fancy_index(self, device):
        # Empty list index creates an empty array
        a = tensor([1, 2, 3], device=device)
        self.assertEqual(a[[]], torch.tensor([], dtype=torch.long, device=device))

        # An empty long tensor index behaves like an empty list index.
        b = tensor([], device=device).long()
        self.assertEqual(a[b], torch.tensor([], dtype=torch.long, device=device))

        # An empty float tensor is not a valid index.
        b = tensor([], device=device).float()
        self.assertRaises(IndexError, lambda: a[b])

    def test_ellipsis_index(self, device):
        a = tensor([[1, 2, 3],
                    [4, 5, 6],
                    [7, 8, 9]], device=device)
        self.assertIsNot(a[...], a)
        self.assertEqual(a[...], a)
        # `a[...]` was `a` in numpy <1.9.
        self.assertEqual(a[...].data_ptr(), a.data_ptr())

        # Slicing with ellipsis can skip an
        # arbitrary number of dimensions
        self.assertEqual(a[0, ...], a[0])
        self.assertEqual(a[0, ...], a[0, :])
        self.assertEqual(a[..., 0], a[:, 0])

        # In NumPy, slicing with ellipsis results in a 0-dim array. In PyTorch
        # we don't have separate 0-dim arrays and scalars.
        self.assertEqual(a[0, ..., 1], torch.tensor(2, device=device))

        # Assignment with `(Ellipsis,)` on 0-d arrays
        b = torch.tensor(1, device=device)
        b[(Ellipsis,)] = 2
        self.assertEqual(b, 2)

    def test_single_int_index(self, device):
        # Single integer index selects one row
        a = tensor([[1, 2, 3],
                    [4, 5, 6],
                    [7, 8, 9]], device=device)

        self.assertEqual(a[0], [1, 2, 3])
        self.assertEqual(a[-1], [7, 8, 9])

        # Index out of bounds produces IndexError
        self.assertRaises(IndexError, a.__getitem__, 1 << 30)
        # Index overflow produces Exception  NB: different exception type
        self.assertRaises(Exception, a.__getitem__, 1 << 64)

    def test_single_bool_index(self, device):
        # Single boolean index: True adds a size-1 dim, False a size-0 dim.
        a = tensor([[1, 2, 3],
                    [4, 5, 6],
                    [7, 8, 9]], device=device)

        self.assertEqual(a[True], a[None])
        self.assertEqual(a[False], a[None][0:0])

    def test_boolean_shape_mismatch(self, device):
        # Masks whose shape does not match the indexed dims raise IndexError
        # mentioning the mask.
        arr = torch.ones((5, 4, 3), device=device)

        index = tensor([True], device=device)
        self.assertRaisesRegex(IndexError, 'mask', lambda: arr[index])

        index = tensor([False] * 6, device=device)
        self.assertRaisesRegex(IndexError, 'mask', lambda: arr[index])

        index = torch.zeros((4, 4), dtype=torch.uint8, device=device)
        self.assertRaisesRegex(IndexError, 'mask', lambda: arr[index])
        self.assertRaisesRegex(IndexError, 'mask', lambda: arr[(slice(None), index)])

    def test_boolean_indexing_onedim(self, device):
        # Indexing a 2-dimensional array with
        # boolean array of length one
        a = tensor([[0., 0., 0.]], device=device)
        b = tensor([True], device=device)
        self.assertEqual(a[b], a)
        # boolean assignment
        a[b] = 1.
        self.assertEqual(a, tensor([[1., 1., 1.]], device=device))

    def test_boolean_assignment_value_mismatch(self, device):
        # A boolean assignment should fail when the shape of the values
        # cannot be broadcast to the subscription. (see also gh-3458)
        a = torch.arange(0, 4, device=device)

        def f(a, v):
            a[a > -1] = tensor(v).to(device)

        self.assertRaisesRegex(Exception, 'shape mismatch', f, a, [])
        self.assertRaisesRegex(Exception, 'shape mismatch', f, a, [1, 2, 3])
        self.assertRaisesRegex(Exception, 'shape mismatch', f, a[:1], [1, 2, 3])

    def test_boolean_indexing_twodim(self, device):
        # Indexing a 2-dimensional array with
        # 2-dimensional boolean array
        a = tensor([[1, 2, 3],
                    [4, 5, 6],
                    [7, 8, 9]], device=device)
        b = tensor([[True, False, True],
                    [False, True, False],
                    [True, False, True]], device=device)
        self.assertEqual(a[b], tensor([1, 3, 5, 7, 9], device=device))
        self.assertEqual(a[b[1]], tensor([[4, 5, 6]], device=device))
        self.assertEqual(a[b[0]], a[b[2]])

        # boolean assignment
        a[b] = 0
        self.assertEqual(a, tensor([[0, 2, 0],
                                    [4, 0, 6],
                                    [0, 8, 0]], device=device))

    def test_boolean_indexing_weirdness(self, device):
        # Weird boolean indexing things (Python bool scalars)
        a = torch.ones((2, 3, 4), device=device)
        self.assertEqual((0, 2, 3, 4), a[False, True, ...].shape)
        self.assertEqual(torch.ones(1, 2, device=device), a[True, [0, 1], True, True, [1], [[2]]])
        self.assertRaises(IndexError, lambda: a[False, [0, 1], ...])

    def test_boolean_indexing_weirdness_tensors(self, device):
        # Same checks as test_boolean_indexing_weirdness, but every boolean
        # index is a 0-dim bool tensor instead of a Python bool.
        false = torch.tensor(False, device=device)
        true = torch.tensor(True, device=device)
        a = torch.ones((2, 3, 4), device=device)
        self.assertEqual((0, 2, 3, 4), a[false, true, ...].shape)
        self.assertEqual(torch.ones(1, 2, device=device), a[true, [0, 1], true, true, [1], [[2]]])
        self.assertRaises(IndexError, lambda: a[false, [0, 1], ...])

    def test_boolean_indexing_alldims(self, device):
        # Two scalar True indices (bool or 0-dim tensor) add one leading dim.
        true = torch.tensor(True, device=device)
        a = torch.ones((2, 3), device=device)
        self.assertEqual((1, 2, 3), a[True, True].shape)
        self.assertEqual((1, 2, 3), a[true, true].shape)

    def test_boolean_list_indexing(self, device):
        # Indexing a 2-dimensional array with
        # boolean lists
        a = tensor([[1, 2, 3],
                    [4, 5, 6],
                    [7, 8, 9]], device=device)
        b = [True, False, False]
        c = [True, True, False]
        self.assertEqual(a[b], tensor([[1, 2, 3]], device=device))
        self.assertEqual(a[b, b], tensor([1], device=device))
        self.assertEqual(a[c], tensor([[1, 2, 3], [4, 5, 6]], device=device))
        self.assertEqual(a[c, c], tensor([1, 5], device=device))

    def test_everything_returns_views(self, device):
        # Before `...` would return a itself.
        a = tensor([5], device=device)

        self.assertIsNot(a, a[()])
        self.assertIsNot(a, a[...])
        self.assertIsNot(a, a[:])

    def test_broaderrors_indexing(self, device):
        # Index lists that cannot be broadcast together raise for get and set.
        a = torch.zeros(5, 5, device=device)
        self.assertRaisesRegex(IndexError, 'shape mismatch', a.__getitem__, ([0, 1], [0, 1, 2]))
        self.assertRaisesRegex(IndexError, 'shape mismatch', a.__setitem__, ([0, 1], [0, 1, 2]), 0)

    def test_trivial_fancy_out_of_bounds(self, device):
        a = torch.zeros(5, device=device)
        if a.is_cuda:
            raise unittest.SkipTest('CUDA asserts instead of raising an exception')
        # A single out-of-bounds entry (at the end, then at the start) of an
        # otherwise valid index tensor must raise for both reads and writes.
        for position, bad_value in ((-1, 10), (0, 11)):
            ind = torch.ones(20, dtype=torch.int64, device=device)
            ind[position] = bad_value
            self.assertRaises(IndexError, a.__getitem__, ind)
            self.assertRaises(IndexError, a.__setitem__, ind, 0)

    def test_index_is_larger(self, device):
        # Simple case of fancy index broadcasting of the index.
        a = torch.zeros((5, 5), device=device)
        a[[[0], [1], [2]], [0, 1, 2]] = tensor([2., 3., 4.], device=device)

        self.assertTrue((a[:3, :3] == tensor([2., 3., 4.], device=device)).all())

    def test_broadcast_subspace(self, device):
        # Assigning a (100, 1) column through a reversed row index broadcasts
        # each value across its whole row.
        a = torch.zeros((100, 100), device=device)
        v = torch.arange(0., 100, device=device)[:, None]
        b = torch.arange(99, -1, -1, device=device).long()
        a[b] = v
        expected = b.float().unsqueeze(1).expand(100, 100)
        self.assertEqual(a, expected)

    def test_truncate_leading_1s(self, device):
        # A (1, 4) value assigned to a length-4 advanced-index target has its
        # leading size-1 dim dropped; compare against a direct diagonal copy.
        col_max = torch.randn(1, 4, device=device)
        kernel = col_max.T * col_max  # [4, 4] tensor
        kernel2 = kernel.clone()
        # Set the diagonal
        kernel[range(len(kernel)), range(len(kernel))] = torch.square(col_max)
        torch.diagonal(kernel2).copy_(torch.square(col_max.view(4)))
        self.assertEqual(kernel, kernel2)


# Generate device-specific variants of the generic test classes for each
# available device type, excluding the 'meta' device.
instantiate_device_type_tests(TestIndexing, globals(), except_for='meta')
instantiate_device_type_tests(NumpyTests, globals(), except_for='meta')

if __name__ == '__main__':
    run_tests()

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.