## @package schema
# Module caffe2.python.schema
"""
Defines a minimal set of data types that allow representing datasets with
arbitrary nested structure, including objects of variable length, such as
maps and lists.

This defines a columnar storage format for such datasets on top of caffe2
tensors. In terms of representational capacity, it can represent most of
the data types supported by the Parquet, ORC, and DWRF file formats.

See comments in operator_test/dataset_ops_test.py for an example and
walkthrough on how to use schema to store and iterate through a structured
in-memory dataset.
"""

import logging
import numpy as np
from caffe2.python import core
from caffe2.python import workspace
from caffe2.python.core import BlobReference
from collections import OrderedDict, namedtuple
from past.builtins import basestring
from itertools import islice
from io import StringIO
from typing import Sequence

logger = logging.getLogger(__name__)

FIELD_SEPARATOR = ':'


def _join_field_name(prefix, suffix):
    if prefix and suffix:
        return '{}{}{}'.format(prefix, FIELD_SEPARATOR, suffix)
    elif prefix:
        return prefix
    elif suffix:
        return suffix
    else:
        return ''


def _normalize_field(field_or_type_or_blob, keep_blobs=True):
    """Clones/normalizes a field before adding it to a container."""
    if isinstance(field_or_type_or_blob, Field):
        return field_or_type_or_blob.clone(keep_blobs=keep_blobs)
    elif type(field_or_type_or_blob) in (type, np.dtype):
        return Scalar(dtype=field_or_type_or_blob)
    else:
        return Scalar(blob=field_or_type_or_blob)


FeatureSpec = namedtuple(
    'FeatureSpec',
    [
        'feature_type',
        'feature_names',
        'feature_ids',
        'feature_is_request_only',
        'desired_hash_size',
        'feature_to_index',
    ]
)

# pyre-fixme[16]: `FeatureSpec.__new__` has no attribute `__defaults__`
FeatureSpec.__new__.__defaults__ = (None, None, None, None, None, None)


class Metadata(
    namedtuple(
        'Metadata', ['categorical_limit', 'expected_value', 'feature_specs']
    )
):
    """Represents additional information associated with a scalar in schema.

    `categorical_limit` - for fields of integral type that are guaranteed to
    be non-negative, specifies the maximum possible value plus one. It's often
    used as the size of an embedding table.

    `expected_value` - anticipated average value of elements in the field.
    Usually makes sense for the length fields of lists.

    `feature_specs` - information about the features contained in this field.
    For example, if the field contains more than one feature, it can hold the
    list of feature names contained in this field."""
    __slots__: Sequence[str] = ()


# pyre-fixme[16]: `Metadata.__new__` has no attribute `__defaults__`
Metadata.__new__.__defaults__ = (None, None, None)
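
# Illustrative usage sketch (not part of the original module; the names below
# are for demonstration only): attaching Metadata to an integral Scalar
# (defined later in this file), e.g. to record the size of an embedding table.
#
#   >>> id_feature = Scalar(np.int64, metadata=Metadata(categorical_limit=1000))
#   >>> id_feature.metadata.categorical_limit
#   1000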


class Field:
    """Represents an abstract field type in a dataset.
    """

    __slots__: Sequence[str] = ("_parent", "_field_offsets")

    def __init__(self, children):
        """Derived classes must call this after their initialization."""
        self._parent = (None, 0)
        offset = 0
        self._field_offsets = []
        for child in children:
            self._field_offsets.append(offset)
            offset += len(child.field_names())
        self._field_offsets.append(offset)

    def clone_schema(self):
        return self.clone(keep_blobs=False)

    def field_names(self):
        """Return the children field names for this field."""
        raise NotImplementedError('Field is an abstract class.')

    def field_types(self):
        """Return the numpy.dtype for each of the children fields."""
        raise NotImplementedError('Field is an abstract class.')

    def field_metadata(self):
        """Return the Metadata for each of the children fields."""
        raise NotImplementedError('Field is an abstract class.')

    def field_blobs(self):
        """Return the list of blobs with contents for this Field.
        Values can either be all numpy.ndarray or BlobReference.
        If any of the fields doesn't have a blob, throws.
        """
        raise NotImplementedError('Field is an abstract class.')

    def all_scalars(self):
        """Return the list of all Scalar instances in the Field.
        The order is the same as for field_names() or field_blobs()."""
        raise NotImplementedError('Field is an abstract class.')

    def has_blobs(self):
        """Return True if every scalar of this field has blobs."""
        raise NotImplementedError('Field is an abstract class.')

    def clone(self, keep_blobs=True):
        """Clone this Field along with its children."""
        raise NotImplementedError('Field is an abstract class.')

    def _set_parent(self, parent, relative_id):
        self._parent = (parent, relative_id)

    def slice(self):
        """
        Returns a slice representing the range of field ids that belong to
        this field. This slice can be used to index a list of fields.

        E.g.:

        >>> s = Struct(
        >>>     ('a', Scalar()),
        >>>     ('b', Struct(
        >>>         ('b1', Scalar()),
        >>>         ('b2', Scalar()),
        >>>     )),
        >>>     ('c', Scalar()),
        >>> )
        >>> field_data = ['da', 'db1', 'db2', 'dc']
        >>> field_data[s.b.slice()]
        ['db1', 'db2']
        """
        base_id = self._child_base_id()
        return slice(base_id, base_id + len(self.field_names()))

    def _child_base_id(self, child_index=None):
        """Get the base id of the given child."""
        p, i = self._parent
        pos = 0 if child_index is None else self._field_offsets[child_index]
        if p:
            pos += p._child_base_id(i)
        return pos

    def __eq__(self, other):
        """Equivalence of two schemas."""
        return (
            (self.field_names() == other.field_names()) and
            (self.field_types() == other.field_types()) and
            (self.field_metadata() == other.field_metadata())
        )

    def _pprint_impl(self, indent, str_buffer):
        raise NotImplementedError('Field is an abstract class.')

    def __repr__(self):
        str_buffer = StringIO()
        self._pprint_impl(0, str_buffer)
        contents = str_buffer.getvalue()
        str_buffer.close()
        return contents


class List(Field):
    """Represents a variable-length list.

    Values of a list can also be complex fields such as Lists and Structs.
    In addition to the fields exposed by its `values` field, a List exposes an
    additional `lengths` field, which will contain the size of each list under
    the parent domain.
    """

    __slots__: Sequence[str] = ("lengths", "_items")

    def __init__(self, values, lengths_blob=None):
        if isinstance(lengths_blob, Field):
            assert isinstance(lengths_blob, Scalar)
            self.lengths = _normalize_field(lengths_blob)
        else:
            self.lengths = Scalar(np.int32, lengths_blob)
        self._items = _normalize_field(values)
        self.lengths._set_parent(self, 0)
        self._items._set_parent(self, 1)
        super().__init__([self.lengths, self._items])

    def field_names(self):
        value_fields = self._items.field_names()
        return (
            ['lengths'] + [_join_field_name('values', v) for v in value_fields]
        )

    def field_types(self):
        return self.lengths.field_types() + self._items.field_types()

    def field_metadata(self):
        return self.lengths.field_metadata() + self._items.field_metadata()

    def field_blobs(self):
        return self.lengths.field_blobs() + self._items.field_blobs()

    def all_scalars(self):
        return self.lengths.all_scalars() + self._items.all_scalars()

    def has_blobs(self):
        return self.lengths.has_blobs() and self._items.has_blobs()

    def clone(self, keep_blobs=True):
        return type(self)(
            _normalize_field(self._items, keep_blobs=keep_blobs),
            _normalize_field(self.lengths, keep_blobs=keep_blobs)
        )

    def _pprint_impl(self, indent, str_buffer):
        str_buffer.write('  ' * indent + "List(\n")
        str_buffer.write('  ' * (indent + 1) + "lengths=\n")
        self.lengths._pprint_impl(indent=indent + 2, str_buffer=str_buffer)
        str_buffer.write('  ' * (indent + 1) + "_items=\n")
        self._items._pprint_impl(indent=indent + 2, str_buffer=str_buffer)
        str_buffer.write('  ' * indent + ")\n")

    def __getattr__(self, item):
        """If the value of this list is a struct, allows introspecting
        directly into its fields."""
        if item.startswith('__'):
            raise AttributeError(item)
        if isinstance(self._items, Struct):
            return getattr(self._items, item)
        elif item == 'value' or item == 'items':
            return self._items
        else:
            raise AttributeError('Field not found in list: %s.' % item)

    def __getitem__(self, item):
        names = item.split(FIELD_SEPARATOR, 1)

        if len(names) == 1:
            if item == 'lengths':
                return self.lengths
            elif item == 'values':
                return self._items
        else:
            if names[0] == 'values':
                return self._items[names[1]]
        raise KeyError('Field not found in list: %s.' % item)
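
# Illustrative sketch (not part of the original module): a List of float
# scalars flattens into a `lengths` column plus a `values` column.
#
#   >>> lst = List(Scalar(np.float32))
#   >>> lst.field_names()
#   ['lengths', 'values']
#   >>> lst.field_types()
#   [dtype('int32'), dtype('float32')]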


class ListWithEvicted(List):
    """
    Similar to List, but containing an extra field, `evicted_values`, for
    LRU hashing.
    """

    __slots__: Sequence[str] = ("_evicted_values",)

    def __init__(self, values, lengths_blob=None, evicted_values=None):
        if isinstance(evicted_values, Field):
            assert isinstance(evicted_values, Scalar)
            self._evicted_values = _normalize_field(evicted_values)
        else:
            self._evicted_values = Scalar(np.int64, evicted_values)
        super().__init__(values, lengths_blob=lengths_blob)

    def field_names(self):
        value_fields = self._items.field_names()
        return (
            ['lengths'] + [_join_field_name('values', v) for v in value_fields] + ["_evicted_values"]
        )

    def field_types(self):
        return self.lengths.field_types() + self._items.field_types() + self._evicted_values.field_types()

    def field_metadata(self):
        return self.lengths.field_metadata() + self._items.field_metadata() + self._evicted_values.field_metadata()

    def field_blobs(self):
        return self.lengths.field_blobs() + self._items.field_blobs() + self._evicted_values.field_blobs()

    def all_scalars(self):
        return self.lengths.all_scalars() + self._items.all_scalars() + self._evicted_values.all_scalars()

    def has_blobs(self):
        return (self.lengths.has_blobs() and self._items.has_blobs() and
                self._evicted_values.has_blobs())

    def clone(self, keep_blobs=True):
        return type(self)(
            _normalize_field(self._items, keep_blobs=keep_blobs),
            _normalize_field(self.lengths, keep_blobs=keep_blobs),
            _normalize_field(self._evicted_values, keep_blobs=keep_blobs)
        )

    def _pprint_impl(self, indent, str_buffer):
        str_buffer.write('  ' * indent + "ListWithEvicted(\n")
        str_buffer.write('  ' * (indent + 1) + "lengths=\n")
        self.lengths._pprint_impl(indent=indent + 2, str_buffer=str_buffer)
        str_buffer.write('  ' * (indent + 1) + "_items=\n")
        self._items._pprint_impl(indent=indent + 2, str_buffer=str_buffer)
        str_buffer.write('  ' * (indent + 1) + "_evicted_values=\n")
        self._evicted_values._pprint_impl(indent=indent + 2, str_buffer=str_buffer)
        str_buffer.write('  ' * indent + ")\n")

    def __getattr__(self, item):
        """If the value of this list is a struct, allows introspecting
        directly into its fields."""
        if item.startswith('__'):
            raise AttributeError(item)
        if item == "_evicted_values":
            return self._evicted_values
        if isinstance(self._items, Struct):
            return getattr(self._items, item)
        elif item == 'value' or item == 'items':
            return self._items
        else:
            raise AttributeError('Field not found in list: %s.' % item)

    def __getitem__(self, item):
        names = item.split(FIELD_SEPARATOR, 1)

        if len(names) == 1:
            if item == 'lengths':
                return self.lengths
            elif item == 'values':
                return self._items
            elif item == '_evicted_values':
                return self._evicted_values
        else:
            if names[0] == 'values':
                return self._items[names[1]]
        raise KeyError('Field not found in list: %s.' % item)


class Struct(Field):
    """Represents a named list of fields sharing the same domain.
    """

    __slots__: Sequence[str] = ("fields", "_frozen")

    def __init__(self, *fields):
        """fields is a list of tuples in the format (name, field). The name
        is a string of nested names, e.g., `a`, `a:b`, `a:b:c`. For example

        Struct(
          ('a', Scalar()),
          ('b:c', Scalar()),
          ('b:d:e', Scalar()),
          ('b', Struct(
            ('f', Scalar()),
          )),
        )

        is equal to

        Struct(
          ('a', Scalar()),
          ('b', Struct(
            ('c', Scalar()),
            ('d', Struct(('e', Scalar()))),
            ('f', Scalar()),
          )),
        )
        """
        for field in fields:
            assert len(field) == 2
            assert field[0], 'Field names cannot be empty'
            assert field[0] != 'lengths', (
                'Struct cannot contain a field named `lengths`.'
            )
        fields = [(name, _normalize_field(field)) for name, field in fields]
        self.fields = OrderedDict()
        for name, field in fields:
            if FIELD_SEPARATOR in name:
                name, field = self._struct_from_nested_name(name, field)
            if name not in self.fields:
                self.fields[name] = field
                continue
            if (
                    not isinstance(field, Struct) or
                    not isinstance(self.fields[name], Struct)
            ):
                raise ValueError('Duplicate field name: %s' % name)
            self.fields[name] = self.fields[name] + field
        for id, (_, field) in enumerate(self.fields.items()):
            field._set_parent(self, id)
        super().__init__(self.fields.values())
        self._frozen = True

    def _struct_from_nested_name(self, nested_name, field):
        def create_internal(nested_name, field):
            names = nested_name.split(FIELD_SEPARATOR, 1)
            if len(names) == 1:
                added_field = field
            else:
                added_field = create_internal(names[1], field)
            return Struct((names[0], added_field))

        names = nested_name.split(FIELD_SEPARATOR, 1)
        assert len(names) >= 2
        return names[0], create_internal(names[1], field)

    def get_children(self):
        return list(self.fields.items())

    def field_names(self):
        names = []
        for name, field in self.fields.items():
            names += [_join_field_name(name, f) for f in field.field_names()]
        return names

    def field_types(self):
        types = []
        for field in self.fields.values():
            types += field.field_types()
        return types

    def field_metadata(self):
        metadata = []
        for field in self.fields.values():
            metadata += field.field_metadata()
        return metadata

    def field_blobs(self):
        blobs = []
        for field in self.fields.values():
            blobs += field.field_blobs()
        return blobs

    def all_scalars(self):
        scalars = []
        for field in self.fields.values():
            scalars += field.all_scalars()
        return scalars

    def has_blobs(self):
        return all(field.has_blobs() for field in self.fields.values())

    def clone(self, keep_blobs=True):
        normalized_fields = [
            (k, _normalize_field(v, keep_blobs=keep_blobs))
            for k, v in self.fields.items()
        ]
        return type(self)(*normalized_fields)

    def _get_field_by_nested_name(self, nested_name):
        names = nested_name.split(FIELD_SEPARATOR, 1)
        field = self.fields.get(names[0], None)

        if field is None:
            return None

        if len(names) == 1:
            return field

        try:
            return field[names[1]]
        except (KeyError, TypeError):
            return None

    def _pprint_impl(self, indent, str_buffer):
        str_buffer.write('  ' * indent + "Struct( \n")
        for name, field in self.fields.items():
            str_buffer.write('  ' * (indent + 1) + "{}=".format(name) + "\n")
            field._pprint_impl(indent=indent + 2, str_buffer=str_buffer)
        str_buffer.write('  ' * indent + ") \n")

    def __contains__(self, item):
        field = self._get_field_by_nested_name(item)
        return field is not None

    def __len__(self):
        return len(self.fields)

    def __getitem__(self, item):
        """
        item can be a tuple or list of ints or strings, or a single
        int or string. A string item is a nested field name, e.g., "a",
        "a:b", "a:b:c". An int item is the index of a field at the first
        level of the Struct.
        """
        if isinstance(item, list) or isinstance(item, tuple):
            keys = list(self.fields.keys())
            return Struct(
                * [
                    (
                        keys[k]
                        if isinstance(k, int) else k, self[k]
                    ) for k in item
                ]
            )
        elif isinstance(item, int):
            return next(islice(self.fields.values(), item, None))
        else:
            field = self._get_field_by_nested_name(item)
            if field is None:
                raise KeyError('field "%s" not found' % (item))
            return field

    def get(self, item, default_value):
        """
        Similar to Python's dictionary get method: returns the field named
        `item` if found (i.e. self.item is valid), otherwise returns
        `default_value`.

        It is syntactic sugar for Python's builtin getattr method.
        """
        return getattr(self, item, default_value)

    def __getattr__(self, item):
        if item.startswith('__'):
            raise AttributeError(item)
        try:
            return super().__getattribute__("fields")[item]
        except KeyError as e:
            raise AttributeError(item) from e

    def __setattr__(self, key, value):
        # Disable setting attributes after initialization to prevent false
        # impression of being able to overwrite a field.
        # Allowing setting internal states mainly so that _parent can be set
        # post initialization.
        if getattr(self, '_frozen', None) and not key.startswith('_'):
            raise TypeError('Struct.__setattr__() is disabled after __init__()')
        super().__setattr__(key, value)

    def __add__(self, other):
        """
        Allows merging fields of two schema.Struct using the '+' operator.
        If two Structs have common field names, the merge is conducted
        recursively. Here are examples:

        Example 1
        s1 = Struct(('a', Scalar()))
        s2 = Struct(('b', Scalar()))
        s1 + s2 == Struct(
            ('a', Scalar()),
            ('b', Scalar()),
        )

        Example 2
        s1 = Struct(
            ('a', Scalar()),
            ('b', Struct(('c', Scalar()))),
        )
        s2 = Struct(('b', Struct(('d', Scalar()))))
        s1 + s2 == Struct(
            ('a', Scalar()),
            ('b', Struct(
                ('c', Scalar()),
                ('d', Scalar()),
            )),
        )
        """
        if not isinstance(other, Struct):
            return NotImplemented

        children = OrderedDict(self.get_children())
        for name, right_field in other.get_children():
            if name not in children:
                children[name] = right_field
                continue
            left_field = children[name]
            if not (isinstance(left_field, Struct) and isinstance(right_field, Struct)):
                raise TypeError(
                    "Type of left_field, " + str(type(left_field)) +
                    ", and type of right_field, " +
                    str(type(right_field)) +
                    ", must both be Struct to allow merging of the field, " + name)
            children[name] = left_field + right_field

        return Struct(*(children.items()))

    def __sub__(self, other):
        """
        Allows removing common fields of two schema.Struct from self by
        using the '-' operator. If two Structs have common field names, the
        removal is conducted recursively. If a child struct has no fields
        inside, it will be removed from its parent. Here are examples:

        Example 1
        s1 = Struct(
            ('a', Scalar()),
            ('b', Scalar()),
        )
        s2 = Struct(('a', Scalar()))
        s1 - s2 == Struct(('b', Scalar()))

        Example 2
        s1 = Struct(
            ('b', Struct(
                ('c', Scalar()),
                ('d', Scalar()),
            ))
        )
        s2 = Struct(
            ('b', Struct(('c', Scalar()))),
        )
        s1 - s2 == Struct(
            ('b', Struct(
                ('d', Scalar()),
            )),
        )

        Example 3
        s1 = Struct(
            ('a', Scalar()),
            ('b', Struct(
                ('d', Scalar()),
            ))
        )
        s2 = Struct(
            ('b', Struct(
                ('c', Scalar()),
                ('d', Scalar()),
            )),
        )
        s1 - s2 == Struct(
            ('a', Scalar()),
        )
        """
        if not isinstance(other, Struct):
            return NotImplemented

        children = OrderedDict(self.get_children())
        for name, right_field in other.get_children():
            if name in children:
                left_field = children[name]
                if type(left_field) == type(right_field):
                    if isinstance(left_field, Struct):
                        child = left_field - right_field
                        if child.get_children():
                            children[name] = child
                            continue
                    children.pop(name)
                else:
                    raise TypeError(
                        "Type of left_field, " + str(type(left_field)) +
                        ", is not the same as that of right_field, " +
                        str(type(right_field)) +
                        ", yet they have the same field name, " + name)
        return Struct(*(children.items()))
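
# Illustrative sketch (not part of the original module): nested names expand
# into nested Structs, and '+' merges two Structs recursively.
#
#   >>> s = Struct(('a', Scalar(np.int32)), ('b:c', Scalar(np.float64)))
#   >>> s.field_names()
#   ['a', 'b:c']
#   >>> (s + Struct(('b', Struct(('d', Scalar())))))['b'].field_names()
#   ['c', 'd']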


class Scalar(Field):
    """Represents a typed scalar or tensor of fixed shape.

    A Scalar is a leaf in a schema tree, translating to exactly one tensor in
    the dataset's underlying storage.

    Usually, the tensor storing the actual values of this field is a 1D tensor,
    representing a series of values in its domain. It is possible however to
    have higher rank values stored as a Scalar, as long as all entries have
    the same shape.

    E.g.:

        Scalar(np.float64)

            Scalar field of type float64. Caffe2 will expect readers and
            datasets to expose it as a 1D tensor of doubles (vector), where
            the size of the vector is determined by this field's domain.

        Scalar((np.int32, 5))

            Tensor field of type int32. Caffe2 will expect readers and
            datasets to implement it as a 2D tensor (matrix) of shape (L, 5),
            where L is determined by this field's domain.

        Scalar((str, (10, 20)))

            Tensor field of type str. Caffe2 will expect readers and
            datasets to implement it as a 3D tensor of shape (L, 10, 20),
            where L is determined by this field's domain.

    If the field type is unknown at construction time, call Scalar(), which
    will default to np.void as its dtype.

    It is an error to pass a structured dtype to Scalar, since it would contain
    more than one field. Instead, use from_dtype, which will construct
    a nested `Struct` field reflecting the given dtype's structure.

    A Scalar can also contain a blob, which represents the value of this
    Scalar. A blob can be either a numpy.ndarray, in which case it contains
    the actual contents of the Scalar, or a BlobReference, which represents a
    blob living in a caffe2 Workspace. If a blob of a different type is
    passed, a conversion to numpy.ndarray is attempted.
    """

    __slots__: Sequence[str] = ("_metadata", "dtype", "_original_dtype", "_blob")

    def __init__(self, dtype=None, blob=None, metadata=None):
        self._metadata = None
        self.set(dtype, blob, metadata, unsafe=True)
        super().__init__([])

    def field_names(self):
        return ['']

    def field_type(self):
        return self.dtype

    def field_types(self):
        return [self.dtype]

    def field_metadata(self):
        return [self._metadata]

    def has_blobs(self):
        return self._blob is not None

    def field_blobs(self):
        assert self._blob is not None, 'Value is not set for this field.'
        return [self._blob]

    def all_scalars(self):
        return [self]

    def clone(self, keep_blobs=True):
        return Scalar(
            dtype=self._original_dtype,
            blob=self._blob if keep_blobs else None,
            metadata=self._metadata
        )

    def get(self):
        """Gets the current blob of this Scalar field."""
        assert self._blob is not None, 'Value is not set for this field.'
        return self._blob

    def __call__(self):
        """Shortcut for self.get()"""
        return self.get()

    @property
    def metadata(self):
        return self._metadata

    def set_metadata(self, value):
        assert isinstance(value, Metadata), \
            'metadata must be Metadata, got {}'.format(type(value))
        self._metadata = value
        self._validate_metadata()

    def _validate_metadata(self):
        if self._metadata is None:
            return
        if (self._metadata.categorical_limit is not None and
                self.dtype is not None):
            assert np.issubdtype(self.dtype, np.integer), \
                "`categorical_limit` can be specified only in integral " + \
                "fields but got {}".format(self.dtype)

    def set_value(self, blob, throw_on_type_mismatch=False, unsafe=False):
        """Sets only the blob field, still validating the existing dtype."""
        if self.dtype.base != np.void and throw_on_type_mismatch:
            assert isinstance(blob, np.ndarray), "Got {!r}".format(blob)
            assert blob.dtype.base == self.dtype.base, (
                "Expected {}, got {}".format(self.dtype.base, blob.dtype.base))
        self.set(dtype=self._original_dtype, blob=blob, unsafe=unsafe)

    def set(self, dtype=None, blob=None, metadata=None, unsafe=False):
        """Set the type and/or blob of this scalar. See __init__ for details.

        Args:
            dtype: can be any numpy type. If not provided and `blob` is
                   provided, it will be inferred. If no argument is provided,
                   this Scalar will be of type np.void.
            blob:  if provided, can be either a BlobReference or a
                   numpy.ndarray. If a value of different type is passed,
                   a conversion to numpy.ndarray is attempted. Strings aren't
                   accepted, since they can be ambiguous. If you want to pass
                   a string, use either BlobReference(blob) or np.array(blob).
            metadata: optional instance of Metadata; if provided, overrides
                      the metadata information of the scalar.
        """
        if not unsafe:
            logger.warning(
                "Scalar should be considered immutable. Only call Scalar.set() "
                "on newly created Scalar with unsafe=True. This will become an "
                "error soon."
            )
        if blob is not None and isinstance(blob, basestring):
            raise ValueError(
                'Passing str blob to Scalar.set() is ambiguous. '
                'Do either set(blob=np.array(blob)) or '
                'set(blob=BlobReference(blob))'
            )

        self._original_dtype = dtype
        # Numpy will collapse a shape of 1 into an unindexed data array (shape = ()),
        # which betrays the docstring of this class (which expects shape = (1,)).
        # >>> np.dtype((np.int32, 1))
        # dtype('int32')
        # >>> np.dtype((np.int32, 5))
        # dtype(('<i4', (5,)))
        if dtype is not None and isinstance(dtype, tuple) and dtype[1] == 1:
            dtype = (dtype[0], (1,))
        if dtype is not None:
            if isinstance(dtype, tuple) and dtype[0] == np.void:
                raise TypeError(
                    "Cannot set the Scalar with type {} for blob {}. "
                    "If this blob is the output of some operation, "
                    "please verify the input of that operation has "
                    "proper type.".format(dtype, blob)
                )
            dtype = np.dtype(dtype)
        # If blob is not None and it is not a BlobReference, we assume that
        # it is actual tensor data, so we will try to cast it to a numpy array.
        if blob is not None and not isinstance(blob, BlobReference):
            preserve_shape = isinstance(blob, np.ndarray)
            if dtype is not None and dtype != np.void:
                blob = np.array(blob, dtype=dtype.base)
                # if array is empty we may need to reshape a little
                if blob.size == 0 and not preserve_shape:
                    blob = blob.reshape((0, ) + dtype.shape)
            else:
                assert isinstance(blob, np.ndarray), (
                    'Invalid blob type: %s' % str(type(blob)))

            # reshape scalars into 1D arrays
            # TODO(azzolini): figure out better way of representing this
            if len(blob.shape) == 0 and not preserve_shape:
                blob = blob.reshape((1, ))

            # infer inner shape from the blob given
            # TODO(dzhulgakov): tweak this to make it work with PackedStruct
            if (len(blob.shape) > 1 and dtype is not None and
                    dtype.base != np.void):
                dtype = np.dtype((dtype.base, blob.shape[1:]))
        # if we were still unable to infer the dtype
        if dtype is None:
            dtype = np.dtype(np.void)
        assert not dtype.fields, (
            'Cannot create Scalar with a structured dtype. ' +
            'Use from_dtype instead.'
        )
        self.dtype = dtype
        self._blob = blob
        if metadata is not None:
            self.set_metadata(metadata)
        self._validate_metadata()

    def set_type(self, dtype):
        self._original_dtype = dtype
        if dtype is not None:
            self.dtype = np.dtype(dtype)
        else:
            self.dtype = np.dtype(np.void)
        self._validate_metadata()

    def _pprint_impl(self, indent, str_buffer):
        str_buffer.write('  ' * (indent) +
            'Scalar({!r}, {!r}, {!r})'.format(
            self.dtype, self._blob, self._metadata) + "\n")

    def id(self):
        """
        Return the zero-indexed position of this scalar field in its schema.
        Used in order to index into the field_blob list returned by readers or
        accepted by writers.
        """
        return self._child_base_id()
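
# Illustrative sketch (not part of the original module): a Scalar infers its
# inner shape from a numpy blob, folding it into the dtype.
#
#   >>> s = Scalar(np.int32, np.array([[1, 2], [3, 4]], dtype=np.int32))
#   >>> s.dtype
#   dtype(('<i4', (2,)))
#   >>> s.get().shape
#   (2, 2)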


def Map(
    keys,
    values,
    keys_name='keys',
    values_name='values',
    lengths_blob=None
):
    """A map is a List of Struct containing keys and values fields.
    Optionally, you can provide custom names for the key and value fields.
    """
    return List(
        Struct((keys_name, keys), (values_name, values)),
        lengths_blob=lengths_blob
    )
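
# Illustrative sketch (not part of the original module): a Map flattens into
# lengths, keys, and values columns.
#
#   >>> m = Map(Scalar(np.int64), Scalar(np.float32))
#   >>> m.field_names()
#   ['lengths', 'values:keys', 'values:values']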


def MapWithEvicted(
    keys,
    values,
    keys_name='keys',
    values_name='values',
    lengths_blob=None,
    evicted_values=None
):
    """A map with an extra field, evicted_values.
    """
    return ListWithEvicted(
        Struct((keys_name, keys), (values_name, values)),
        lengths_blob=lengths_blob,
        evicted_values=evicted_values
    )


def NamedTuple(name_prefix, *fields):
    return Struct(* [('%s_%d' % (name_prefix, i), field)
                     for i, field in enumerate(fields)])


def Tuple(*fields):
    """
    Creates a Struct with default, sequential, field names of given types.
    """
    return NamedTuple('field', *fields)


def RawTuple(num_fields, name_prefix='field'):
    """
    Creates a tuple of `num_fields` untyped scalars.
    """
    assert isinstance(num_fields, int)
    assert num_fields >= 0
    return NamedTuple(name_prefix, *([np.void] * num_fields))
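
# Illustrative sketch (not part of the original module): Tuple assigns
# sequential default names to its fields.
#
#   >>> Tuple(np.int32, str).field_names()
#   ['field_0', 'field_1']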


def from_dtype(dtype, _outer_shape=()):
    """Constructs a Caffe2 schema from the given numpy's dtype.

    Numpy supports scalar, array-like and structured datatypes, as long as
    all the shapes are fixed. This function breaks down the given dtype into
    a Caffe2 schema containing `Struct` and `Scalar` types.

    Fields containing byte offsets are not currently supported.
    """
    if not isinstance(dtype, np.dtype):
        # wrap it into a np.dtype
        shape = _outer_shape
        dtype = np.dtype((dtype, _outer_shape))
    else:
        # concatenate shapes if necessary
        shape = _outer_shape + dtype.shape
        if shape != dtype.shape:
            dtype = np.dtype((dtype.base, shape))

    if not dtype.fields:
        return Scalar(dtype)

    struct_fields = []
    for name, (fdtype, offset) in dtype.fields.items():
        assert offset == 0, ('Fields with byte offsets are not supported.')
        struct_fields.append((name, from_dtype(fdtype, _outer_shape=shape)))
    return Struct(*struct_fields)
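
# Illustrative sketch (not part of the original module): a non-structured
# array dtype becomes a single Scalar, while a structured dtype becomes a
# Struct of Scalars (fields with nonzero byte offsets are rejected).
#
#   >>> from_dtype(np.dtype((np.float32, 2)))
#   Scalar(dtype(('<f4', (2,))), None, None)
#   >>> from_dtype(np.dtype([('x', np.int32)])).field_names()
#   ['x']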


class _SchemaNode:
    """This is a private class used to represent a Schema Node"""

    __slots__: Sequence[str] = ("name", "children", "type_str", "field")

    def __init__(self, name, type_str=''):
        self.name = name
        self.children = []
        self.type_str = type_str
        self.field = None

    def add_child(self, name, type_str=''):
        for child in self.children:
            if child.name == name and child.type_str == type_str:
                return child
        child = _SchemaNode(name, type_str)
        self.children.append(child)
        return child

    def get_field(self):
        list_names = ['lengths', 'values']
        map_names = ['lengths', 'keys', 'values']

        if len(self.children) == 0 or self.field is not None:
            if self.field is None:
                return Struct()
            else:
                return self.field

        child_names = []
        for child in self.children:
            child_names.append(child.name)

        if (set(child_names) == set(list_names)):
            for child in self.children:
                if child.name == 'values':
                    values_field = child.get_field()
                else:
                    lengths_field = child.get_field()
            self.field = List(
                values_field,
                lengths_blob=lengths_field
            )
            self.type_str = "List"
            return self.field
        elif (set(child_names) == set(map_names)):
            for child in self.children:
                if child.name == 'keys':
                    key_field = child.get_field()
                elif child.name == 'values':
                    values_field = child.get_field()
                else:
                    lengths_field = child.get_field()
            self.field = Map(
                key_field,
                values_field,
                lengths_blob=lengths_field
            )
            self.type_str = "Map"
            return self.field
        else:
            struct_fields = []
            for child in self.children:
                struct_fields.append((child.name, child.get_field()))

            self.field = Struct(*struct_fields)
            self.type_str = "Struct"
            return self.field

    def print_recursively(self):
        for child in self.children:
            child.print_recursively()
        logger.info("Printing node: Name and type")
        logger.info(self.name)
        logger.info(self.type_str)


def from_column_list(
    col_names, col_types=None,
    col_blobs=None, col_metadata=None
):
    """
    Given a list of names, types, and optionally values, construct a Schema.
    """
    if col_types is None:
        col_types = [None] * len(col_names)
    if col_metadata is None:
        col_metadata = [None] * len(col_names)
    if col_blobs is None:
        col_blobs = [None] * len(col_names)
    assert len(col_names) == len(col_types), (
        'col_names and col_types must have the same length.'
    )
    assert len(col_names) == len(col_metadata), (
        'col_names and col_metadata must have the same length.'
    )
    assert len(col_names) == len(col_blobs), (
        'col_names and col_blobs must have the same length.'
    )
    root = _SchemaNode('root', 'Struct')
    for col_name, col_type, col_blob, col_md in zip(
        col_names, col_types, col_blobs, col_metadata
    ):
        columns = col_name.split(FIELD_SEPARATOR)
        current = root
        for i in range(len(columns)):
            name = columns[i]
            type_str = ''
            field = None
            if i == len(columns) - 1:
                type_str = col_type
                field = Scalar(
                    dtype=col_type,
                    blob=col_blob,
                    metadata=col_md
                )
            next = current.add_child(name, type_str)
            if field is not None:
                next.field = field
            current = next

    return root.get_field()
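
# Illustrative sketch (not part of the original module): column names that
# match the List layout are reassembled into a List field.
#
#   >>> s = from_column_list(['a:lengths', 'a:values'], [np.int32, np.float32])
#   >>> s.field_names()
#   ['a:lengths', 'a:values']
#   >>> type(s.a).__name__
#   'List'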


def from_blob_list(schema, values, throw_on_type_mismatch=False):
    """
    Create a schema that clones the given schema, but containing the given
    list of values.
    """
    assert isinstance(schema, Field), 'Argument `schema` must be a Field.'
    if isinstance(values, BlobReference):
        values = [values]
    record = schema.clone_schema()
    scalars = record.all_scalars()
    assert len(scalars) == len(values), (
        'Values must have %d elements, got %d.' % (len(scalars), len(values))
    )
    for scalar, value in zip(scalars, values):
        scalar.set_value(value, throw_on_type_mismatch, unsafe=True)
    return record


def as_record(value):
    if isinstance(value, Field):
        return value
    elif isinstance(value, list) or isinstance(value, tuple):
        is_field_list = all(
            isinstance(f, tuple) and len(f) == 2 and isinstance(f[0], basestring)
            for f in value
        )
        if is_field_list:
            return Struct(* [(k, as_record(v)) for k, v in value])
        else:
            return Tuple(* [as_record(f) for f in value])
    elif isinstance(value, dict):
        return Struct(* [(k, as_record(v)) for k, v in value.items()])
    else:
        return _normalize_field(value)
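
# Illustrative sketch (not part of the original module): a dict of numpy
# types is normalized into a Struct of Scalars.
#
#   >>> rec = as_record({'x': np.int64, 'y': np.float32})
#   >>> sorted(rec.field_names())
#   ['x', 'y']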


def FetchRecord(blob_record, ws=None, throw_on_type_mismatch=False):
    """
    Given a record containing BlobReferences, return a new record with the
    same schema, containing numpy arrays, fetched from the current active
    workspace.
    """

    def fetch(v):
        if ws is None:
            return workspace.FetchBlob(str(v))
        else:
            return ws.blobs[str(v)].fetch()

    assert isinstance(blob_record, Field)
    field_blobs = blob_record.field_blobs()
    assert all(isinstance(v, BlobReference) for v in field_blobs)
    field_arrays = [fetch(value) for value in field_blobs]
    return from_blob_list(blob_record, field_arrays, throw_on_type_mismatch)


def FeedRecord(blob_record, arrays, ws=None):
    """
    Given a record of BlobReferences and `arrays`, which is either a list of
    numpy arrays or a record containing numpy arrays, feeds the record to the
    current workspace.
    """

    def feed(b, v):
        if ws is None:
            workspace.FeedBlob(str(b), v)
        else:
            ws.create_blob(str(b))
            ws.blobs[str(b)].feed(v)

    assert isinstance(blob_record, Field)
    field_blobs = blob_record.field_blobs()
    assert all(isinstance(v, BlobReference) for v in field_blobs)
    if isinstance(arrays, Field):
        # TODO: check schema
        arrays = arrays.field_blobs()
    assert len(arrays) == len(field_blobs), (
        'Values must contain exactly %d ndarrays.' % len(field_blobs)
    )
    for blob, array in zip(field_blobs, arrays):
        feed(blob, array)


def NewRecord(net, schema):
    """
    Given a record of np.arrays, create a BlobReference for each one of them,
    returning a record containing BlobReferences. The name of each returned
    blob is NextScopedBlob(field_name), which guarantees a unique name in the
    current net. Use NameScope explicitly to avoid name conflicts between
    different nets.
    """
    if isinstance(schema, Scalar):
        result = schema.clone()
        result.set_value(
            blob=net.NextScopedBlob('unnamed_scalar'),
            unsafe=True,
        )
        return result

    assert isinstance(schema, Field), 'Record must be a schema.Field instance.'
    blob_refs = [
        net.NextScopedBlob(prefix=name)
        for name in schema.field_names()
    ]
    return from_blob_list(schema, blob_refs)
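
# Illustrative sketch (not part of the original module) of a typical round
# trip, assuming an active default workspace; the net name and field name
# below are arbitrary.
#
#   >>> net = core.Net('example')
#   >>> record = NewRecord(net, Struct(('x', Scalar(np.float32))))
#   >>> FeedRecord(record, [np.array([1.0], dtype=np.float32)])
#   >>> FetchRecord(record).x.get()
#   array([1.], dtype=float32)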


def ConstRecord(net, array_record):
    """
    Given a record of arrays, returns a record of blobs,
    initialized with net.Const.
    """
    blob_record = NewRecord(net, array_record)
    for blob, array in zip(
        blob_record.field_blobs(), array_record.field_blobs()
    ):
        net.Const(array, blob)
    return blob_record


def InitEmptyRecord(net, schema_or_record, enforce_types=False):
    if not schema_or_record.has_blobs():
        record = NewRecord(net, schema_or_record)
    else:
        record = schema_or_record

    for blob_type, blob in zip(record.field_types(), record.field_blobs()):
        try:
            data_type = data_type_for_dtype(blob_type)
            shape = [0] + list(blob_type.shape)
            net.ConstantFill([], blob, shape=shape, dtype=data_type)
        except TypeError:
            logger.warning("Blob {} has type error".format(blob))
            # If data_type_for_dtype doesn't know how to resolve the given
            # numpy type to core.DataType, that function can throw a TypeError
            # (for example, that would happen for unknown types such as
            # np.void). This is not a problem when the record is going to be
            # overwritten by some operator later, though it might be an issue
            # for type/shape inference.
            if enforce_types:
                raise
            # If we don't enforce types for all items we'll create a blob with
            # the default ConstantFill (FLOAT, no shape)
            net.ConstantFill([], blob, shape=[0])

    return record


_DATA_TYPE_FOR_DTYPE = [
    (str, core.DataType.STRING),
    (np.float16, core.DataType.FLOAT16),
    (np.float32, core.DataType.FLOAT),
    (np.float64, core.DataType.DOUBLE),
    (bool, core.DataType.BOOL),
    (np.int8, core.DataType.INT8),
    (np.int16, core.DataType.INT16),
    (np.int32, core.DataType.INT32),
    (np.int64, core.DataType.INT64),
    (np.uint8, core.DataType.UINT8),
    (np.uint16, core.DataType.UINT16),
]


def is_schema_subset(schema, original_schema):
    # TODO add more checks
    return set(schema.field_names()).issubset(
        set(original_schema.field_names()))


def equal_schemas(schema,
                  original_schema,
                  check_field_names=True,
                  check_field_types=True,
                  check_field_metas=False):
    assert isinstance(schema, Field)
    assert isinstance(original_schema, Field)

    if check_field_names and (
            schema.field_names() != original_schema.field_names()):
        return False
    if check_field_types and (
            schema.field_types() != original_schema.field_types()):
        return False
    if check_field_metas and (
            schema.field_metadata() != original_schema.field_metadata()):
        return False

    return True


def schema_check(schema, previous=None):
    record = as_record(schema)
    if previous is not None:
        assert equal_schemas(schema, previous)
    return record


def data_type_for_dtype(dtype):
    for np_type, dt in _DATA_TYPE_FOR_DTYPE:
        if dtype.base == np_type:
            return dt
    raise TypeError('Unknown dtype: ' + str(dtype.base))


def dtype_for_core_type(core_type):
    for np_type, dt in _DATA_TYPE_FOR_DTYPE:
        if dt == core_type:
            return np_type
    raise TypeError('Unknown core type: ' + str(core_type))


def attach_metadata_to_scalars(field, metadata):
    for f in field.all_scalars():
        f.set_metadata(metadata)
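
# Illustrative sketch (not part of the original module): the two lookup
# helpers above map between numpy dtypes and core.DataType values via the
# _DATA_TYPE_FOR_DTYPE table.
#
#   >>> data_type_for_dtype(np.dtype(np.float32)) == core.DataType.FLOAT
#   True
#   >>> dtype_for_core_type(core.DataType.INT64) is np.int64
#   True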