4
from cython cimport dataclasses
5
from cython.dataclasses cimport dataclass, field
8
from typing import ClassVar
9
from dataclasses import InitVar
10
import dataclasses as py_dataclasses
14
from libc.stdlib cimport malloc, free
16
include "../testsupport/cythonarrayutil.pxi"
18
cdef class NotADataclass:
26
return "string of NotADataclass" # should not be called - repr is called instead!
28
def __eq__(self, other):
29
return type(self) == type(other)
34
@dataclass(unsafe_hash=True)
35
cdef class BasicDataclass:
37
>>> sorted(list(BasicDataclass.__dataclass_fields__.keys()))
40
# Check the field type attribute - this is currently a string since
41
# it's taken from the annotation, but if we drop PEP563 in future
43
>>> BasicDataclass.__dataclass_fields__["a"].type
45
>>> BasicDataclass.__dataclass_fields__["b"].type
47
>>> BasicDataclass.__dataclass_fields__["c"].type
49
>>> BasicDataclass.__dataclass_fields__["d"].type
52
>>> inst1 = BasicDataclass() # doctest: +ELLIPSIS
53
Traceback (most recent call last):
54
TypeError: __init__() takes at least 1 ...
55
>>> inst1 = BasicDataclass(2.0)
57
# The error at-least demonstrates that the hash function has been created
58
>>> hash(inst1) # doctest: +ELLIPSIS
59
Traceback (most recent call last):
60
TypeError: ...unhashable...
61
>>> inst2 = BasicDataclass(2.0)
64
>>> inst2 = BasicDataclass(2.0, NotADataclass(), [])
67
>>> inst2 = BasicDataclass(2.0, NotADataclass(), [], [1,2,3])
69
BasicDataclass(a=2.0, b=NADC, c=[], d=[1, 2, 3])
70
>>> inst2.c = "Some string"
72
BasicDataclass(a=2.0, b=NADC, c='Some string', d=[1, 2, 3])
75
b: NotADataclass = field(default_factory=NotADataclass)
76
c: object = field(default=0)
77
d: list = dataclasses.field(default_factory=list)
80
cdef class InheritsFromDataclass(BasicDataclass):
82
>>> sorted(list(InheritsFromDataclass.__dataclass_fields__.keys()))
83
['a', 'b', 'c', 'd', 'e']
84
>>> InheritsFromDataclass(a=1.0, e=5)
86
InheritsFromDataclass(a=1.0, b=NADC, c=0, d=[], e=5)
90
def __post_init__(self):
91
print "In __post_init__"
93
@cython.dataclasses.dataclass
94
cdef class InheritsFromNotADataclass(NotADataclass):
96
>>> sorted(list(InheritsFromNotADataclass.__dataclass_fields__.keys()))
98
>>> InheritsFromNotADataclass()
99
InheritsFromNotADataclass(c=1)
100
>>> InheritsFromNotADataclass(5)
101
InheritsFromNotADataclass(c=5)
111
cdef S_ptr malloc_a_struct():
112
return <S_ptr>malloc(sizeof(S))
115
cdef class ContainsNonPyFields:
117
>>> ContainsNonPyFields() # doctest: +ELLIPSIS
118
Traceback (most recent call last):
119
TypeError: __init__() takes ... 1 positional ...
120
>>> ContainsNonPyFields(mystruct={'a': 1 }) # doctest: +ELLIPSIS
121
ContainsNonPyFields(mystruct={'a': 1}, memview=<MemoryView of 'array' at ...>)
122
>>> ContainsNonPyFields(mystruct={'a': 1 }, memview=create_array((2,2), "c")) # doctest: +ELLIPSIS
123
ContainsNonPyFields(mystruct={'a': 1}, memview=<MemoryView of 'array' at ...>)
124
>>> ContainsNonPyFields(mystruct={'a': 1 }, mystruct_ptr=0)
125
Traceback (most recent call last):
126
TypeError: __init__() got an unexpected keyword argument 'mystruct_ptr'
128
mystruct: S = cython.dataclasses.field(compare=False)
129
mystruct_ptr: S_ptr = field(init=False, repr=False, default_factory=malloc_a_struct)
130
memview: cython.int[:, ::1] = field(default=create_array((3,1), "c"), # mutable so not great but OK for a test
133
def __dealloc__(self):
134
free(self.mystruct_ptr)
137
cdef class InitClassVars:
139
Private (i.e. defined with "cdef") members deliberately don't appear
140
TODO - ideally c1 and c2 should also be listed here
141
>>> sorted(list(InitClassVars.__dataclass_fields__.keys()))
147
>>> inst1 = InitClassVars()
149
>>> inst1 # init vars don't appear in string
151
>>> inst2 = InitClassVars(b1=5, d2=100)
153
>>> inst1 == inst2 # comparison ignores the initvar
157
b1: InitVar[cython.double] = 1.0
158
b2: py_dataclasses.InitVar[cython.double] = 1.0
159
c1: ClassVar[float] = 2.0
160
c2: typing.ClassVar[float] = 2.0
161
cdef InitVar[cython.int] d1
162
cdef py_dataclasses.InitVar[cython.int] d2
165
cdef ClassVar[list] e1
166
cdef typing.ClassVar[list] e2
170
def __post_init__(self, b1, b2, d1, d2):
171
# Check that the initvars haven't been assigned yet
172
assert self.b1==0, self.b1
173
assert self.b2==0, self.b2
174
assert self.d1==0, self.d1
175
assert self.d2==0, self.d2
180
print "In __post_init__"
183
cdef class TestVisibility:
185
>>> inst = TestVisibility()
186
>>> "a" in TestVisibility.__dataclass_fields__
188
>>> hasattr(inst, "a")
190
>>> "b" in TestVisibility.__dataclass_fields__
192
>>> hasattr(inst, "b")
194
>>> "c" in TestVisibility.__dataclass_fields__
196
>>> TestVisibility.__dataclass_fields__["c"].type
198
>>> hasattr(inst, "c")
200
>>> "d" in TestVisibility.__dataclass_fields__
202
>>> TestVisibility.__dataclass_fields__["d"].type
204
>>> hasattr(inst, "d")
209
b: cython.double = 2.0
215
@dataclass(frozen=True)
216
cdef class TestFrozen:
218
>>> inst = TestFrozen(a=5)
221
>>> inst.a = 2. # doctest: +ELLIPSIS
222
Traceback (most recent call last):
223
AttributeError: attribute 'a' of '...TestFrozen' objects is not writable
225
a: cython.double = 2.0
227
def get_dataclass_initvar():
228
return py_dataclasses.InitVar
231
@dataclass(kw_only=True)
232
cdef class TestKwOnly:
234
>>> inst = TestKwOnly(a=3, b=2)
239
>>> inst = TestKwOnly(b=2)
244
>>> fail = TestKwOnly(3, 2)
245
Traceback (most recent call last):
246
TypeError: __init__() takes exactly 0 positional arguments (2 given)
247
>>> fail = TestKwOnly(a=3)
248
Traceback (most recent call last):
249
TypeError: __init__() needs keyword-only argument b
250
>>> fail = TestKwOnly()
251
Traceback (most recent call last):
252
TypeError: __init__() needs keyword-only argument b
255
a: cython.double = 2.0
260
>>> from dataclasses import Field, is_dataclass, fields, InitVar
262
# It uses the types from the standard library where available
263
>>> all(isinstance(v, Field) for v in BasicDataclass.__dataclass_fields__.values())
266
# check out Cython dataclasses are close enough to convince it
267
>>> is_dataclass(BasicDataclass)
269
>>> is_dataclass(BasicDataclass(1.5))
271
>>> is_dataclass(InheritsFromDataclass)
273
>>> is_dataclass(NotADataclass)
275
>>> is_dataclass(InheritsFromNotADataclass)
277
>>> [ f.name for f in fields(BasicDataclass)]
279
>>> [ f.name for f in fields(InitClassVars)]
281
>>> get_dataclass_initvar() == InitVar