"""
Defines a minimal set of data types that allow representing datasets with
arbitrary nested structure, including objects of variable length, such as
maps and lists.

This defines a columnar storage format for such datasets on top of caffe2
tensors. In terms of capacity of representation, it can represent most of
the data types supported by the Parquet, ORC, and DWRF file formats.

See comments in operator_test/dataset_ops_test.py for an example and
walkthrough on how to use schema to store and iterate through a structured
in-memory dataset.
"""
import logging

import numpy as np
from caffe2.python import core
from caffe2.python import workspace
from caffe2.python.core import BlobReference
from collections import OrderedDict, namedtuple
from past.builtins import basestring
from itertools import islice
from io import StringIO
from typing import Sequence

logger = logging.getLogger(__name__)

FIELD_SEPARATOR = ':'


def _join_field_name(prefix, suffix):
    if prefix and suffix:
        return '{}{}{}'.format(prefix, FIELD_SEPARATOR, suffix)
    elif prefix:
        return prefix
    elif suffix:
        return suffix
    else:
        return ''
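
# A quick illustration of the helper above (doctest-style, following the
# conventions used elsewhere in this file). Note that Scalars expose the
# empty string as their field name, so empty suffixes must collapse:
# >>> _join_field_name('b', 'c')
# 'b:c'
# >>> _join_field_name('a', '')
# 'a'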


def _normalize_field(field_or_type_or_blob, keep_blobs=True):
    """Clones/normalizes a field before adding it to a container."""
    if isinstance(field_or_type_or_blob, Field):
        return field_or_type_or_blob.clone(keep_blobs=keep_blobs)
    elif type(field_or_type_or_blob) in (type, np.dtype):
        return Scalar(dtype=field_or_type_or_blob)
    else:
        return Scalar(blob=field_or_type_or_blob)


FeatureSpec = namedtuple(
    'FeatureSpec',
    [
        'feature_type',
        'feature_names',
        'feature_ids',
        'feature_is_request_only',
        'desired_hash_size',
        'feature_to_index',
    ]
)

FeatureSpec.__new__.__defaults__ = (None, None, None, None, None, None)


class Metadata(
    namedtuple(
        'Metadata', ['categorical_limit', 'expected_value', 'feature_specs']
    )
):
    """Represents additional information associated with a scalar in schema.

    `categorical_limit` - for fields of integral type that are guaranteed to
    be non-negative, it specifies the maximum possible value plus one. It's
    often used as the size of an embedding table.

    `expected_value` - anticipated average value of elements in the field.
    Usually makes sense for length fields of lists.

    `feature_specs` - information about the features contained in this field.
    For example, if the field holds more than one feature, it can carry the
    list of feature names."""

    __slots__: Sequence[str] = ()


Metadata.__new__.__defaults__ = (None, None, None)
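
# Example (a sketch): attaching Metadata to an integral Scalar so that the
# field can later size an embedding table. `categorical_limit` may only be
# set on integral fields; `_validate_metadata` asserts this.
# >>> s = Scalar(np.int32)
# >>> s.set_metadata(Metadata(categorical_limit=1000))
# >>> s.metadata.categorical_limit
# 1000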


class Field:
    """Represents an abstract field type in a dataset.
    """

    __slots__: Sequence[str] = ("_parent", "_field_offsets")

    def __init__(self, children):
        """Derived classes must call this after their initialization."""
        self._parent = (None, 0)
        offset = 0
        self._field_offsets = []
        for child in children:
            self._field_offsets.append(offset)
            offset += len(child.field_names())
        self._field_offsets.append(offset)

    def clone_schema(self):
        return self.clone(keep_blobs=False)

    def field_names(self):
        """Return the children field names for this field."""
        raise NotImplementedError('Field is an abstract class.')

    def field_types(self):
        """Return the numpy.dtype for each of the children fields."""
        raise NotImplementedError('Field is an abstract class.')

    def field_metadata(self):
        """Return the Metadata for each of the children fields."""
        raise NotImplementedError('Field is an abstract class.')

    def field_blobs(self):
        """Return the list of blobs with contents for this Field.
        Values can either be all numpy.ndarray or BlobReference.
        If any of the fields doesn't have a blob, throws.
        """
        raise NotImplementedError('Field is an abstract class.')

    def all_scalars(self):
        """Return the list of all Scalar instances in the Field.
        The order is the same as for field_names() or field_blobs()."""
        raise NotImplementedError('Field is an abstract class.')

    def has_blobs(self):
        """Return True if every scalar of this field has blobs."""
        raise NotImplementedError('Field is an abstract class.')

    def clone(self, keep_blobs=True):
        """Clone this Field along with its children."""
        raise NotImplementedError('Field is an abstract class.')

    def _set_parent(self, parent, relative_id):
        self._parent = (parent, relative_id)

    def slice(self):
        """
        Returns a slice representing the range of field ids that belong to
        this field. This slice can be used to index a list of fields.

        E.g.:

        >>> s = Struct(
        >>>     ('a', Scalar()),
        >>>     ('b', Struct(
        >>>         ('b1', Scalar()),
        >>>         ('b2', Scalar()),
        >>>     )),
        >>>     ('c', Scalar()),
        >>> )
        >>> field_data = ['da', 'db1', 'db2', 'dc']
        >>> field_data[s.b.slice()]
        ['db1', 'db2']
        """
        base_id = self._child_base_id()
        return slice(base_id, base_id + len(self.field_names()))

    def _child_base_id(self, child_index=None):
        """Get the base id of the given child."""
        p, i = self._parent
        pos = 0 if child_index is None else self._field_offsets[child_index]
        if p:
            pos += p._child_base_id(i)
        return pos

    def __eq__(self, other):
        """Equivalence of two schemas."""
        return (
            (self.field_names() == other.field_names()) and
            (self.field_types() == other.field_types()) and
            (self.field_metadata() == other.field_metadata())
        )

    def _pprint_impl(self, indent, str_buffer):
        raise NotImplementedError('Field is an abstract class.')

    def __repr__(self):
        str_buffer = StringIO()
        self._pprint_impl(0, str_buffer)
        contents = str_buffer.getvalue()
        str_buffer.close()
        return contents


class List(Field):
    """Represents a variable-length list.

    Values of a list can also be complex fields such as Lists and Structs.
    In addition to the fields exposed by its `values` field, a List exposes an
    additional `lengths` field, which will contain the size of each list under
    the parent domain.
    """

    __slots__: Sequence[str] = ("lengths", "_items")

    def __init__(self, values, lengths_blob=None):
        if isinstance(lengths_blob, Field):
            assert isinstance(lengths_blob, Scalar)
            self.lengths = _normalize_field(lengths_blob)
        else:
            self.lengths = Scalar(np.int32, lengths_blob)
        self._items = _normalize_field(values)
        self.lengths._set_parent(self, 0)
        self._items._set_parent(self, 1)
        super().__init__([self.lengths, self._items])

    def field_names(self):
        value_fields = self._items.field_names()
        return (
            ['lengths'] + [_join_field_name('values', v) for v in value_fields]
        )

    def field_types(self):
        return self.lengths.field_types() + self._items.field_types()

    def field_metadata(self):
        return self.lengths.field_metadata() + self._items.field_metadata()

    def field_blobs(self):
        return self.lengths.field_blobs() + self._items.field_blobs()

    def all_scalars(self):
        return self.lengths.all_scalars() + self._items.all_scalars()

    def has_blobs(self):
        return self.lengths.has_blobs() and self._items.has_blobs()

    def clone(self, keep_blobs=True):
        return type(self)(
            _normalize_field(self._items, keep_blobs=keep_blobs),
            _normalize_field(self.lengths, keep_blobs=keep_blobs)
        )

    def _pprint_impl(self, indent, str_buffer):
        str_buffer.write(' ' * indent + "List(\n")
        str_buffer.write(' ' * (indent + 1) + "lengths=\n")
        self.lengths._pprint_impl(indent=indent + 2, str_buffer=str_buffer)
        str_buffer.write(' ' * (indent + 1) + "_items=\n")
        self._items._pprint_impl(indent=indent + 2, str_buffer=str_buffer)
        str_buffer.write(' ' * indent + ")\n")

    def __getattr__(self, item):
        """If the value of this list is a struct,
        allow to introspect directly into its fields."""
        if item.startswith('__'):
            raise AttributeError(item)
        if isinstance(self._items, Struct):
            return getattr(self._items, item)
        elif item == 'value' or item == 'items':
            return self._items
        else:
            raise AttributeError('Field not found in list: %s.' % item)

    def __getitem__(self, item):
        names = item.split(FIELD_SEPARATOR, 1)

        if len(names) == 1:
            if item == 'lengths':
                return self.lengths
            elif item == 'values':
                return self._items
        else:
            if names[0] == 'values':
                return self._items[names[1]]
        raise KeyError('Field not found in list: %s.' % item)
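
# Example (a sketch of typical usage): a List schema exposes a `lengths`
# scalar next to its values, and nested Struct values can be addressed with
# ':'-separated names.
# >>> s = List(Struct(('a', Scalar(np.int32))))
# >>> s.field_names()
# ['lengths', 'values:a']
# >>> s['values:a']    # returns the inner Scalar for field 'a'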


class ListWithEvicted(List):
    """
    This class is similar to List, but holds an extra field, `evicted_values`,
    for LRU hashing.
    """

    __slots__: Sequence[str] = ("_evicted_values",)

    def __init__(self, values, lengths_blob=None, evicted_values=None):
        if isinstance(evicted_values, Field):
            assert isinstance(evicted_values, Scalar)
            self._evicted_values = _normalize_field(evicted_values)
        else:
            self._evicted_values = Scalar(np.int64, evicted_values)
        super().__init__(values, lengths_blob=lengths_blob)

    def field_names(self):
        value_fields = self._items.field_names()
        return (
            ['lengths'] + [_join_field_name('values', v) for v in value_fields]
            + ["_evicted_values"]
        )

    def field_types(self):
        return (
            self.lengths.field_types() + self._items.field_types() +
            self._evicted_values.field_types()
        )

    def field_metadata(self):
        return (
            self.lengths.field_metadata() + self._items.field_metadata() +
            self._evicted_values.field_metadata()
        )

    def field_blobs(self):
        return (
            self.lengths.field_blobs() + self._items.field_blobs() +
            self._evicted_values.field_blobs()
        )

    def all_scalars(self):
        return (
            self.lengths.all_scalars() + self._items.all_scalars() +
            self._evicted_values.all_scalars()
        )

    def has_blobs(self):
        return (
            self.lengths.has_blobs() and self._items.has_blobs()
            and self._evicted_values.has_blobs()
        )

    def clone(self, keep_blobs=True):
        return type(self)(
            _normalize_field(self._items, keep_blobs=keep_blobs),
            _normalize_field(self.lengths, keep_blobs=keep_blobs),
            _normalize_field(self._evicted_values, keep_blobs=keep_blobs)
        )

    def _pprint_impl(self, indent, str_buffer):
        str_buffer.write(' ' * indent + "ListWithEvicted(\n")
        str_buffer.write(' ' * (indent + 1) + "lengths=\n")
        self.lengths._pprint_impl(indent=indent + 2, str_buffer=str_buffer)
        str_buffer.write(' ' * (indent + 1) + "_items=\n")
        self._items._pprint_impl(indent=indent + 2, str_buffer=str_buffer)
        str_buffer.write(' ' * (indent + 1) + "_evicted_values=\n")
        self._evicted_values._pprint_impl(indent=indent + 2, str_buffer=str_buffer)
        str_buffer.write(' ' * indent + ")\n")

    def __getattr__(self, item):
        """If the value of this list is a struct,
        allow to introspect directly into its fields."""
        if item.startswith('__'):
            raise AttributeError(item)
        if item == "_evicted_values":
            return self._evicted_values
        if isinstance(self._items, Struct):
            return getattr(self._items, item)
        elif item == 'value' or item == 'items':
            return self._items
        else:
            raise AttributeError('Field not found in list: %s.' % item)

    def __getitem__(self, item):
        names = item.split(FIELD_SEPARATOR, 1)

        if len(names) == 1:
            if item == 'lengths':
                return self.lengths
            elif item == 'values':
                return self._items
            elif item == '_evicted_values':
                return self._evicted_values
        else:
            if names[0] == 'values':
                return self._items[names[1]]
        raise KeyError('Field not found in list: %s.' % item)


class Struct(Field):
    """Represents a named list of fields sharing the same domain.
    """

    __slots__: Sequence[str] = ("fields", "_frozen")

    def __init__(self, *fields):
        """fields is a list of tuples in format of (name, field). The name is
        a string of nested name, e.g., `a`, `a:b`, `a:b:c`. For example

        Struct(
          ('a', Scalar()),
          ('b:c', Scalar()),
          ('b:d:e', Scalar()),
          ('b', Struct(
            ('f', Scalar()),
          )),
        )

        is equal to

        Struct(
          ('a', Scalar()),
          ('b', Struct(
            ('c', Scalar()),
            ('d', Struct(('e', Scalar()))),
            ('f', Scalar()),
          )),
        )
        """
        for field in fields:
            assert len(field) == 2
            assert field[0], 'Field names cannot be empty'
            assert field[0] != 'lengths', (
                'Struct cannot contain a field named `lengths`.'
            )
        fields = [(name, _normalize_field(field)) for name, field in fields]
        self.fields = OrderedDict()
        for name, field in fields:
            if FIELD_SEPARATOR in name:
                name, field = self._struct_from_nested_name(name, field)
            if name not in self.fields:
                self.fields[name] = field
                continue
            if (
                not isinstance(field, Struct) or
                not isinstance(self.fields[name], Struct)
            ):
                raise ValueError('Duplicate field name: %s' % name)
            self.fields[name] = self.fields[name] + field
        for id, (_, field) in enumerate(self.fields.items()):
            field._set_parent(self, id)
        super().__init__(self.fields.values())
        self._frozen = True

    def _struct_from_nested_name(self, nested_name, field):
        def create_internal(nested_name, field):
            names = nested_name.split(FIELD_SEPARATOR, 1)
            if len(names) == 1:
                added_field = field
            else:
                added_field = create_internal(names[1], field)
            return Struct((names[0], added_field))

        names = nested_name.split(FIELD_SEPARATOR, 1)
        assert len(names) >= 2
        return names[0], create_internal(names[1], field)

    def get_children(self):
        return list(self.fields.items())

    def field_names(self):
        names = []
        for name, field in self.fields.items():
            names += [_join_field_name(name, f) for f in field.field_names()]
        return names

    def field_types(self):
        types = []
        for field in self.fields.values():
            types += field.field_types()
        return types

    def field_metadata(self):
        metadata = []
        for field in self.fields.values():
            metadata += field.field_metadata()
        return metadata

    def field_blobs(self):
        blobs = []
        for field in self.fields.values():
            blobs += field.field_blobs()
        return blobs

    def all_scalars(self):
        scalars = []
        for field in self.fields.values():
            scalars += field.all_scalars()
        return scalars

    def has_blobs(self):
        return all(field.has_blobs() for field in self.fields.values())

    def clone(self, keep_blobs=True):
        normalized_fields = [
            (k, _normalize_field(v, keep_blobs=keep_blobs))
            for k, v in self.fields.items()
        ]
        return type(self)(*normalized_fields)

    def _get_field_by_nested_name(self, nested_name):
        names = nested_name.split(FIELD_SEPARATOR, 1)
        field = self.fields.get(names[0], None)

        if field is None:
            return None

        if len(names) == 1:
            return field

        try:
            return field[names[1]]
        except (KeyError, TypeError):
            return None

    def _pprint_impl(self, indent, str_buffer):
        str_buffer.write(' ' * indent + "Struct( \n")
        for name, field in self.fields.items():
            str_buffer.write(' ' * (indent + 1) + "{}=".format(name) + "\n")
            field._pprint_impl(indent=indent + 2, str_buffer=str_buffer)
        str_buffer.write(' ' * indent + ") \n")

    def __contains__(self, item):
        field = self._get_field_by_nested_name(item)
        return field is not None

    def __len__(self):
        return len(self.fields)

    def __getitem__(self, item):
        """
        item can be a tuple or list of ints or strings, or a single
        int or string. String item is a nested field name, e.g., "a", "a:b",
        "a:b:c". Int item is the index of a field at the first level of the
        Struct.
        """
        if isinstance(item, list) or isinstance(item, tuple):
            keys = list(self.fields.keys())
            return Struct(
                * [
                    (
                        keys[k]
                        if isinstance(k, int) else k, self[k]
                    ) for k in item
                ]
            )
        elif isinstance(item, int):
            return next(islice(self.fields.values(), item, None))
        else:
            field = self._get_field_by_nested_name(item)
            if field is None:
                raise KeyError('field "%s" not found' % (item))
            return field

    def get(self, item, default_value):
        """
        Similar to Python's dictionary get method: return the field named
        `item` if found (i.e. self.item is valid), otherwise return
        `default_value`.

        It's syntactic sugar for Python's builtin getattr method.
        """
        return getattr(self, item, default_value)
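
    # Example (a sketch of the accessors above): nested names, integer
    # indexing, and `get` with a default.
    # >>> s = Struct(('a', Scalar(np.int32)), ('b:c', Scalar(np.float32)))
    # >>> s.field_names()
    # ['a', 'b:c']
    # >>> s['b:c'] is s.b.c
    # True
    # >>> s[0] is s.a
    # True
    # >>> s.get('missing', None) is None
    # True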

    def __getattr__(self, item):
        if item.startswith('__'):
            raise AttributeError(item)
        try:
            return super().__getattribute__("fields")[item]
        except KeyError as e:
            raise AttributeError(item) from e

    def __setattr__(self, key, value):
        # Disable setting attributes after initialization to prevent the
        # false impression of being able to overwrite a field. Internal
        # state (keys starting with '_') can still be set, e.g. `_parent`.
        if getattr(self, '_frozen', None) and not key.startswith('_'):
            raise TypeError('Struct.__setattr__() is disabled after __init__()')
        super().__setattr__(key, value)

    def __add__(self, other):
        """
        Allows merging the fields of two schema.Struct using the '+' operator.
        If two Structs have common field names, the merge is conducted
        recursively. Here are examples:

        Example 1
        s1 = Struct(('a', Scalar()))
        s2 = Struct(('b', Scalar()))
        s1 + s2 == Struct(
            ('a', Scalar()),
            ('b', Scalar()),
        )

        Example 2
        s1 = Struct(
            ('a', Scalar()),
            ('b', Struct(('c', Scalar()))),
        )
        s2 = Struct(('b', Struct(('d', Scalar()))))
        s1 + s2 == Struct(
            ('a', Scalar()),
            ('b', Struct(
                ('c', Scalar()),
                ('d', Scalar()),
            )),
        )
        """
        if not isinstance(other, Struct):
            return NotImplemented

        children = OrderedDict(self.get_children())
        for name, right_field in other.get_children():
            if name not in children:
                children[name] = right_field
                continue
            left_field = children[name]
            if not (isinstance(left_field, Struct) and isinstance(right_field, Struct)):
                raise TypeError(
                    "Type of left_field, " + str(type(left_field)) +
                    ", and type of right_field, " +
                    str(type(right_field)) +
                    ", must both be Struct to allow merging of the field, " + name)
            children[name] = left_field + right_field

        return Struct(*(children.items()))

    def __sub__(self, other):
        """
        Allows removing the common fields of two schema.Struct from self by
        using the '-' operator. If two Structs have common field names, the
        removal is conducted recursively. If a child struct has no fields
        left inside, it will be removed from its parent. Here are examples:

        Example 1
        s1 = Struct(
            ('a', Scalar()),
            ('b', Scalar()),
        )
        s2 = Struct(('a', Scalar()))
        s1 - s2 == Struct(('b', Scalar()))

        Example 2
        s1 = Struct(
            ('b', Struct(
                ('c', Scalar()),
                ('d', Scalar()),
            ))
        )
        s2 = Struct(
            ('b', Struct(('c', Scalar()))),
        )
        s1 - s2 == Struct(
            ('b', Struct(
                ('d', Scalar()),
            )),
        )
        """
        if not isinstance(other, Struct):
            return NotImplemented

        children = OrderedDict(self.get_children())
        for name, right_field in other.get_children():
            if name in children:
                left_field = children[name]
                if type(left_field) == type(right_field):
                    if isinstance(left_field, Struct):
                        child = left_field - right_field
                        if child.get_children():
                            children[name] = child
                            continue
                    children.pop(name)
                else:
                    raise TypeError(
                        "Type of left_field, " + str(type(left_field)) +
                        ", is not the same as that of right_field, " +
                        str(type(right_field)) +
                        ", yet they have the same field name, " + name)
        return Struct(*(children.items()))


class Scalar(Field):
    """Represents a typed scalar or tensor of fixed shape.

    A Scalar is a leaf in a schema tree, translating to exactly one tensor in
    the dataset's underlying storage.

    Usually, the tensor storing the actual values of this field is a 1D
    tensor, representing a series of values in its domain. It is possible,
    however, to have higher rank values stored as a Scalar, as long as all
    entries have the same shape.

    E.g.:

        Scalar(np.float64)

            Scalar field of type float64. Caffe2 will expect readers and
            datasets to expose it as a 1D tensor of doubles (vector), where
            the size of the vector is determined by this field's domain.

        Scalar((np.int32, 5))

            Tensor field of type int32. Caffe2 will expect readers and
            datasets to implement it as a 2D tensor (matrix) of shape (L, 5),
            where L is determined by this field's domain.

        Scalar((str, (10, 20)))

            Tensor field of type str. Caffe2 will expect readers and
            datasets to implement it as a 3D tensor of shape (L, 10, 20),
            where L is determined by this field's domain.

    If the field type is unknown at construction time, call Scalar(), which
    will default to np.void as its dtype.

    It is an error to pass a structured dtype to Scalar, since it would
    contain more than one field. Instead, use from_dtype, which will construct
    a nested `Struct` field reflecting the given dtype's structure.

    A Scalar can also contain a blob, which represents the value of this
    Scalar. A blob can be either a numpy.ndarray, in which case it contains
    the actual contents of the Scalar, or a BlobReference, which represents a
    blob living in a caffe2 Workspace. If a blob of a different type is
    passed, a conversion to numpy.ndarray is attempted.
    """

    __slots__: Sequence[str] = ("_metadata", "dtype", "_original_dtype", "_blob")

    def __init__(self, dtype=None, blob=None, metadata=None):
        self._metadata = None
        self.set(dtype, blob, metadata, unsafe=True)

    def field_names(self):
        return ['']

    def field_type(self):
        return self.dtype

    def field_types(self):
        return [self.dtype]

    def field_metadata(self):
        return [self._metadata]

    def has_blobs(self):
        return self._blob is not None

    def field_blobs(self):
        assert self._blob is not None, 'Value is not set for this field.'
        return [self._blob]

    def all_scalars(self):
        return [self]

    def clone(self, keep_blobs=True):
        return Scalar(
            dtype=self._original_dtype,
            blob=self._blob if keep_blobs else None,
            metadata=self._metadata
        )

    def get(self):
        """Gets the current blob of this Scalar field."""
        assert self._blob is not None, 'Value is not set for this field.'
        return self._blob

    def __call__(self):
        """Shortcut for self.get()"""
        return self.get()

    @property
    def metadata(self):
        return self._metadata

    def set_metadata(self, value):
        assert isinstance(value, Metadata), \
            'metadata must be Metadata, got {}'.format(type(value))
        self._metadata = value
        self._validate_metadata()

    def _validate_metadata(self):
        if self._metadata is None:
            return
        if (self._metadata.categorical_limit is not None and
                self.dtype is not None):
            assert np.issubdtype(self.dtype, np.integer), \
                "`categorical_limit` can be specified only in integral " + \
                "fields but got {}".format(self.dtype)

    def set_value(self, blob, throw_on_type_mismatch=False, unsafe=False):
        """Sets only the blob field, still validating the existing dtype."""
        if self.dtype.base != np.void and throw_on_type_mismatch:
            assert isinstance(blob, np.ndarray), "Got {!r}".format(blob)
            assert blob.dtype.base == self.dtype.base, (
                "Expected {}, got {}".format(self.dtype.base, blob.dtype.base))
        self.set(dtype=self._original_dtype, blob=blob, unsafe=unsafe)

    def set(self, dtype=None, blob=None, metadata=None, unsafe=False):
        """Set the type and/or blob of this scalar. See __init__ for details.

        Args:
            dtype: can be any numpy type. If not provided and `blob` is
                   provided, it will be inferred. If no argument is provided,
                   this Scalar will be of type np.void.
            blob:  if provided, can be either a BlobReference or a
                   numpy.ndarray. If a value of different type is passed,
                   a conversion to numpy.ndarray is attempted. Strings aren't
                   accepted, since they can be ambiguous. If you want to pass
                   a string, wrap it in either BlobReference(blob) or
                   np.array(blob) first.
            metadata: optional instance of Metadata; if provided, overrides
                      the metadata information of the scalar.
        """
        if not unsafe:
            logger.warning(
                "Scalar should be considered immutable. Only call Scalar.set() "
                "on newly created Scalar with unsafe=True. This will become an "
                "error soon."
            )
        if blob is not None and isinstance(blob, basestring):
            raise ValueError(
                'Passing str blob to Scalar.set() is ambiguous. '
                'Do either set(blob=np.array(blob)) or '
                'set(blob=BlobReference(blob))'
            )

        self._original_dtype = dtype
        # Numpy will collapse a shape of 1 into an unindexed data array
        # (shape = ()), which betrays the docstring of this class (which
        # expects shape = (1,)).
        # >>> np.dtype((np.int32, 1))
        # dtype('int32')
        # >>> np.dtype((np.int32, 5))
        # dtype(('<i4', (5,)))
        if dtype is not None and isinstance(dtype, tuple) and dtype[1] == 1:
            dtype = (dtype[0], (1,))
        if dtype is not None:
            if isinstance(dtype, tuple) and dtype[0] == np.void:
                raise TypeError(
                    "Cannot set the Scalar with type {} for blob {}. "
                    "If this blob is the output of some operation, "
                    "please verify the input of that operation has "
                    "proper type.".format(dtype, blob)
                )
            dtype = np.dtype(dtype)
        # If blob is not None and is not a BlobReference, we assume it is
        # actual tensor data, so we try to convert it to a numpy array.
        if blob is not None and not isinstance(blob, BlobReference):
            preserve_shape = isinstance(blob, np.ndarray)
            if dtype is not None and dtype != np.void:
                blob = np.array(blob, dtype=dtype.base)
                # If the array is empty, we may need to reshape a little.
                if blob.size == 0 and not preserve_shape:
                    blob = blob.reshape((0, ) + dtype.shape)
            else:
                assert isinstance(blob, np.ndarray), (
                    'Invalid blob type: %s' % str(type(blob)))

            # Reshape scalars into 1D arrays.
            if len(blob.shape) == 0 and not preserve_shape:
                blob = blob.reshape((1, ))

            # Infer the inner shape from the given blob.
            if (len(blob.shape) > 1 and dtype is not None and
                    dtype.base != np.void):
                dtype = np.dtype((dtype.base, blob.shape[1:]))
        # If we were still unable to infer the dtype, default to np.void.
        if dtype is None:
            dtype = np.dtype(np.void)
        assert not dtype.fields, (
            'Cannot create Scalar with a structured dtype. ' +
            'Use from_dtype instead.'
        )
        self.dtype = dtype
        self._blob = blob
        if metadata is not None:
            self.set_metadata(metadata)
        self._validate_metadata()

    def set_type(self, dtype):
        self._original_dtype = dtype
        if dtype is not None:
            self.dtype = np.dtype(dtype)
        else:
            self.dtype = np.dtype(np.void)
        self._validate_metadata()

    def _pprint_impl(self, indent, str_buffer):
        str_buffer.write(' ' * (indent) +
                         'Scalar({!r}, {!r}, {!r})'.format(
                             self.dtype, self._blob, self._metadata) + "\n")

    def id(self):
        """
        Return the zero-indexed position of this scalar field in its schema.
        Used in order to index into the field_blob list returned by readers
        or accepted by writers.
        """
        return self._child_base_id()
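
# Example (a sketch): dtype and shape handling in Scalar. A (dtype, k) tuple
# declares a fixed inner shape, and feeding an ndarray infers the inner
# shape from everything past the first (batch) dimension.
# >>> Scalar((np.int32, 5)).field_types()
# [dtype(('<i4', (5,)))]
# >>> s = Scalar(np.float32, blob=np.zeros((4, 3), dtype=np.float32))
# >>> s.field_type().shape
# (3,)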


def Map(
    keys,
    values,
    keys_name='keys',
    values_name='values',
    lengths_blob=None
):
    """A map is a List of Struct containing keys and values fields.
    Optionally, you can provide custom names for the keys and values fields.
    """
    return List(
        Struct((keys_name, keys), (values_name, values)),
        lengths_blob=lengths_blob
    )
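
# Example (a sketch): a Map is just sugar over List-of-Struct, which shows
# up in its flattened field names.
# >>> m = Map(Scalar(np.int64), Scalar(np.float32))
# >>> m.field_names()
# ['lengths', 'values:keys', 'values:values']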


def MapWithEvicted(
    keys,
    values,
    keys_name='keys',
    values_name='values',
    lengths_blob=None,
    evicted_values=None
):
    """A map with an extra field, evicted_values.
    """
    return ListWithEvicted(
        Struct((keys_name, keys), (values_name, values)),
        lengths_blob=lengths_blob,
        evicted_values=evicted_values
    )


def NamedTuple(name_prefix, *fields):
    return Struct(* [('%s_%d' % (name_prefix, i), field)
                     for i, field in enumerate(fields)])


def Tuple(*fields):
    """
    Creates a Struct with default, sequential, field names of given types.
    """
    return NamedTuple('field', *fields)


def RawTuple(num_fields, name_prefix='field'):
    """
    Creates a tuple of `num_fields` untyped scalars.
    """
    assert isinstance(num_fields, int)
    assert num_fields >= 0
    return NamedTuple(name_prefix, *([np.void] * num_fields))
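
# Example (a sketch): sequential field naming from Tuple/RawTuple. RawTuple
# produces the same names but with untyped (np.void) scalars.
# >>> Tuple(np.int32, np.float32).field_names()
# ['field_0', 'field_1']
# >>> RawTuple(2).field_names()
# ['field_0', 'field_1']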


def from_dtype(dtype, _outer_shape=()):
    """Constructs a Caffe2 schema from the given numpy dtype.

    Numpy supports scalar, array-like and structured datatypes, as long as
    all the shapes are fixed. This function breaks down the given dtype into
    a Caffe2 schema containing `Struct` and `Scalar` types.

    Fields containing byte offsets are not currently supported.
    """
    if not isinstance(dtype, np.dtype):
        # wrap into a ndtype
        shape = _outer_shape
        dtype = np.dtype((dtype, _outer_shape))
    else:
        # concatenate shapes if necessary
        shape = _outer_shape + dtype.shape
        if shape != dtype.shape:
            dtype = np.dtype((dtype.base, shape))

    if not dtype.fields:
        # standard, scalar field
        return Scalar(dtype=dtype)

    struct_fields = []
    # dtype.fields maps name -> (field dtype, byte offset); iterate items()
    # and accumulate (name, field) pairs for Struct.
    for name, (fdtype, offset) in dtype.fields.items():
        assert offset == 0, ('Fields with byte offsets are not supported.')
        struct_fields.append((name, from_dtype(fdtype, _outer_shape=shape)))
    return Struct(*struct_fields)
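
# Example (a sketch): a structured dtype is decomposed into a Struct of
# Scalars. Note the `offset == 0` assert above rejects multi-field dtypes,
# whose later members have nonzero byte offsets, so only single-field
# structured dtypes pass through.
# >>> dt = np.dtype([('x', (np.int32, 2))])
# >>> from_dtype(dt).field_names()
# ['x']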


class _SchemaNode(object):
    """This is a private class used to represent a Schema Node."""

    __slots__: Sequence[str] = ("name", "children", "type_str", "field")

    def __init__(self, name, type_str=''):
        self.name = name
        self.children = []
        self.type_str = type_str
        self.field = None

    def add_child(self, name, type_str=''):
        for child in self.children:
            if child.name == name and child.type_str == type_str:
                return child
        child = _SchemaNode(name, type_str)
        self.children.append(child)
        return child

    def get_field(self):

        list_names = ['lengths', 'values']
        map_names = ['lengths', 'keys', 'values']

        if len(self.children) == 0 or self.field is not None:
            if self.field is None:
                return Struct()
            else:
                return self.field

        child_names = []
        for child in self.children:
            child_names.append(child.name)

        if (set(child_names) == set(list_names)):
            for child in self.children:
                if child.name == 'values':
                    values_field = child.get_field()
                else:
                    lengths_field = child.get_field()
            self.field = List(
                values_field,
                lengths_blob=lengths_field
            )
            self.type_str = "List"
            return self.field
        elif (set(child_names) == set(map_names)):
            for child in self.children:
                if child.name == 'keys':
                    key_field = child.get_field()
                elif child.name == 'values':
                    values_field = child.get_field()
                else:
                    lengths_field = child.get_field()
            self.field = Map(
                key_field,
                values_field,
                lengths_blob=lengths_field
            )
            self.type_str = "Map"
            return self.field
        else:
            struct_fields = []
            for child in self.children:
                struct_fields.append((child.name, child.get_field()))
            self.field = Struct(*struct_fields)
            self.type_str = "Struct"
            return self.field

    def print_recursively(self):
        for child in self.children:
            child.print_recursively()
        logger.info("Printing node: Name and type")
        logger.info(self.name)
        logger.info(self.type_str)


def from_column_list(
    col_names, col_types=None,
    col_blobs=None, col_metadata=None
):
    """
    Given a list of names, types, and optionally values, construct a Schema.
    """
    if col_types is None:
        col_types = [None] * len(col_names)
    if col_metadata is None:
        col_metadata = [None] * len(col_names)
    if col_blobs is None:
        col_blobs = [None] * len(col_names)
    assert len(col_names) == len(col_types), (
        'col_names and col_types must have the same length.'
    )
    assert len(col_names) == len(col_metadata), (
        'col_names and col_metadata must have the same length.'
    )
    assert len(col_names) == len(col_blobs), (
        'col_names and col_blobs must have the same length.'
    )
    root = _SchemaNode('root', 'Struct')
    for col_name, col_type, col_blob, col_md in zip(
        col_names, col_types, col_blobs, col_metadata
    ):
        columns = col_name.split(FIELD_SEPARATOR)
        current = root
        for i in range(len(columns)):
            name = columns[i]
            type_str = ''
            field = None
            if i == len(columns) - 1:
                type_str = col_type
                field = Scalar(
                    dtype=col_type,
                    blob=col_blob,
                    metadata=col_md
                )
            next = current.add_child(name, type_str)
            if field is not None:
                next.field = field
            current = next

    return root.get_field()
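
# Example (a sketch): rebuilding a nested schema from flattened column
# names. Columns named '<x>:lengths' / '<x>:values' collapse back into a
# List.
# >>> s = from_column_list(['a', 'b:lengths', 'b:values'])
# >>> s.field_names()
# ['a', 'b:lengths', 'b:values']
# >>> isinstance(s.b, List)
# True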


def from_blob_list(schema, values, throw_on_type_mismatch=False):
    """
    Create a schema that clones the given schema, but containing the given
    list of blobs instead of any meaningful data.
    """
    assert isinstance(schema, Field), 'Argument `schema` must be a Field.'
    if isinstance(values, BlobReference):
        values = [values]
    record = schema.clone_schema()
    scalars = record.all_scalars()
    assert len(scalars) == len(values), (
        'Values must have %d elements, got %d.' % (len(scalars), len(values))
    )
    for scalar, value in zip(scalars, values):
        scalar.set_value(value, throw_on_type_mismatch, unsafe=True)
    return record


def as_record(value):
    if isinstance(value, Field):
        return value
    elif isinstance(value, list) or isinstance(value, tuple):
        is_field_list = all(
            isinstance(f, tuple) and len(f) == 2 and isinstance(f[0], basestring)
            for f in value
        )
        if is_field_list:
            return Struct(* [(k, as_record(v)) for k, v in value])
        else:
            return Tuple(* [as_record(f) for f in value])
    elif isinstance(value, dict):
        return Struct(* [(k, as_record(v)) for k, v in value.items()])
    else:
        return _normalize_field(value)


def FetchRecord(blob_record, ws=None, throw_on_type_mismatch=False):
    """
    Given a record containing BlobReferences, return a new record with same
    schema, containing numpy arrays, fetched from the current active workspace.
    """

    def fetch(v):
        if ws is None:
            return workspace.FetchBlob(str(v))
        else:
            return ws.blobs[str(v)].fetch()

    assert isinstance(blob_record, Field)
    field_blobs = blob_record.field_blobs()
    assert all(isinstance(v, BlobReference) for v in field_blobs)
    field_arrays = [fetch(value) for value in field_blobs]
    return from_blob_list(blob_record, field_arrays, throw_on_type_mismatch)


def FeedRecord(blob_record, arrays, ws=None):
    """
    Given a Record containing blob_references and arrays, which is either
    a list of numpy arrays or a Record containing numpy arrays, feeds the
    record to the current workspace.
    """

    def feed(b, v):
        if ws is None:
            workspace.FeedBlob(str(b), v)
        else:
            ws.create_blob(str(b))
            ws.blobs[str(b)].feed(v)

    assert isinstance(blob_record, Field)
    field_blobs = blob_record.field_blobs()
    assert all(isinstance(v, BlobReference) for v in field_blobs)
    if isinstance(arrays, Field):
        # TODO: check schema
        arrays = arrays.field_blobs()
    assert len(arrays) == len(field_blobs), (
        'Values must contain exactly %d ndarrays.' % len(field_blobs)
    )
    for blob, array in zip(field_blobs, arrays):
        feed(blob, array)


def NewRecord(net, schema):
    """
    Given a record of np.arrays, create a BlobReference for each one of them,
    returning a record containing BlobReferences. The name of each returned
    blob is NextScopedBlob(field_name), which guarantees unique names in the
    current net. Use NameScope explicitly to avoid name conflicts between
    different nets.
    """
    if isinstance(schema, Scalar):
        result = schema.clone()
        result.set_value(
            blob=net.NextScopedBlob('unnamed_scalar'),
            unsafe=True,
        )
        return result

    assert isinstance(schema, Field), 'Record must be a schema.Field instance.'
    blob_refs = [
        net.NextScopedBlob(prefix=name)
        for name in schema.field_names()
    ]
    return from_blob_list(schema, blob_refs)
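
# Example (a sketch of the intended workflow, assuming an active default
# workspace): allocate scoped blobs for a schema, feed numpy data into them,
# and fetch them back as a record of ndarrays.
# >>> net = core.Net('example')
# >>> rec = NewRecord(net, Struct(('x', Scalar(np.float32))))
# >>> FeedRecord(rec, [np.array([1.0, 2.0], dtype=np.float32)])
# >>> FetchRecord(rec).x.get()
# array([1., 2.], dtype=float32)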


def ConstRecord(net, array_record):
    """
    Given a record of arrays, returns a record of blobs,
    initialized with net.Const.
    """
    blob_record = NewRecord(net, array_record)
    for blob, array in zip(
        blob_record.field_blobs(), array_record.field_blobs()
    ):
        net.Const(array, blob)
    return blob_record


def InitEmptyRecord(net, schema_or_record, enforce_types=False):
    if not schema_or_record.has_blobs():
        record = NewRecord(net, schema_or_record)
    else:
        record = schema_or_record

    for blob_type, blob in zip(record.field_types(), record.field_blobs()):
        try:
            data_type = data_type_for_dtype(blob_type)
            shape = [0] + list(blob_type.shape)
            net.ConstantFill([], blob, shape=shape, dtype=data_type)
        except TypeError:
            logger.warning("Blob {} has type error".format(blob))
            # If data_type_for_dtype doesn't know how to resolve the given
            # numpy type to core.DataType, it throws TypeError (e.g. for
            # np.void). Unless types are enforced, fall back to the default
            # ConstantFill (FLOAT, shape [0]).
            if enforce_types:
                raise
            net.ConstantFill([], blob, shape=[0])

    return record


_DATA_TYPE_FOR_DTYPE = [
    (str, core.DataType.STRING),
    (np.float16, core.DataType.FLOAT16),
    (np.float32, core.DataType.FLOAT),
    (np.float64, core.DataType.DOUBLE),
    (bool, core.DataType.BOOL),
    (np.int8, core.DataType.INT8),
    (np.int16, core.DataType.INT16),
    (np.int32, core.DataType.INT32),
    (np.int64, core.DataType.INT64),
    (np.uint8, core.DataType.UINT8),
    (np.uint16, core.DataType.UINT16),
]


def is_schema_subset(schema, original_schema):
    return set(schema.field_names()).issubset(
        set(original_schema.field_names()))


def equal_schemas(schema,
                  original_schema,
                  check_field_names=True,
                  check_field_types=True,
                  check_field_metas=False):
    assert isinstance(schema, Field)
    assert isinstance(original_schema, Field)

    if check_field_names and (
            schema.field_names() != original_schema.field_names()):
        return False
    if check_field_types and (
            schema.field_types() != original_schema.field_types()):
        return False
    if check_field_metas and (
            schema.field_metadata() != original_schema.field_metadata()):
        return False

    return True


def schema_check(schema, previous=None):
    record = as_record(schema)
    if previous is not None:
        assert equal_schemas(schema, previous)
    return record


def data_type_for_dtype(dtype):
    for np_type, dt in _DATA_TYPE_FOR_DTYPE:
        if dtype.base == np_type:
            return dt
    raise TypeError('Unknown dtype: ' + str(dtype.base))
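
# Example (a sketch): mapping between numpy dtypes and core.DataType values
# via the table above.
# >>> data_type_for_dtype(np.dtype(np.float32)) == core.DataType.FLOAT
# True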


def dtype_for_core_type(core_type):
    for np_type, dt in _DATA_TYPE_FOR_DTYPE:
        if dt == core_type:
            return np_type
    raise TypeError('Unknown core type: ' + str(core_type))


def attach_metadata_to_scalars(field, metadata):
    for f in field.all_scalars():
        f.set_metadata(metadata)