# Page header captured with this file -- cython | Fork 0 | cdef_class_dataclass.pyx | 283 lines, 7.8 KB
# mode: run
# tag: dataclass

# NOTE: the scraped copy of this file had a line number inserted before every
# source line; those artifacts are removed throughout.
from cython cimport dataclasses
from cython.dataclasses cimport dataclass, field
# Best-effort imports: if they fail, the names stay undefined and the tests that
# use InitVar / ClassVar / typing / py_dataclasses (InitClassVars and
# get_dataclass_initvar) will fail, rather than the whole module.
try:
    import typing
    from typing import ClassVar
    from dataclasses import InitVar
    import dataclasses as py_dataclasses
except ImportError:
    pass
import cython
from libc.stdlib cimport malloc, free

# Presumably provides create_array(), which is used below but not defined in
# this file -- confirm against the .pxi.
include "../testsupport/cythonarrayutil.pxi"

# A plain cdef class with no dataclass decorator. It is used below as a field
# type / default_factory (BasicDataclass.b) and as a base class
# (InheritsFromNotADataclass). Its "NADC" repr is what shows up inside the
# generated dataclass reprs.
cdef class NotADataclass:
    cdef cython.int a
    b: float

    def __repr__(self):
        return "NADC"

    def __str__(self):
        return "string of NotADataclass"  # should not be called - repr is called instead!

    def __eq__(self, other):
        # All instances of exactly this type compare equal to one another.
        return type(self) == type(other)

    def __hash__(self):
        # Constant hash, consistent with the type-only __eq__ above.
        return 1

# Basic cdef dataclass, created with unsafe_hash=True so that a __hash__ is
# generated. Hashing an instance fails at runtime because field "d" holds a
# list. The fields mix the cimported `field`, the `dataclasses.field` spelling,
# a plain default and default_factory values.
@dataclass(unsafe_hash=True)
cdef class BasicDataclass:
    """
    >>> sorted(list(BasicDataclass.__dataclass_fields__.keys()))
    ['a', 'b', 'c', 'd']

    # Check the field type attribute - this is currently a string since
    # it's taken from the annotation, but if we drop PEP563 in future
    # then it may change
    >>> BasicDataclass.__dataclass_fields__["a"].type
    'float'
    >>> BasicDataclass.__dataclass_fields__["b"].type
    'NotADataclass'
    >>> BasicDataclass.__dataclass_fields__["c"].type
    'object'
    >>> BasicDataclass.__dataclass_fields__["d"].type
    'list'

    >>> inst1 = BasicDataclass() # doctest: +ELLIPSIS
    Traceback (most recent call last):
    TypeError: __init__() takes at least 1 ...
    >>> inst1 = BasicDataclass(2.0)

    # The error at-least demonstrates that the hash function has been created
    >>> hash(inst1) # doctest: +ELLIPSIS
    Traceback (most recent call last):
    TypeError: ...unhashable...
    >>> inst2 = BasicDataclass(2.0)
    >>> inst1 == inst2
    True
    >>> inst2 = BasicDataclass(2.0, NotADataclass(), [])
    >>> inst1 == inst2
    False
    >>> inst2 = BasicDataclass(2.0, NotADataclass(), [], [1,2,3])
    >>> inst2
    BasicDataclass(a=2.0, b=NADC, c=[], d=[1, 2, 3])
    >>> inst2.c = "Some string"
    >>> inst2
    BasicDataclass(a=2.0, b=NADC, c='Some string', d=[1, 2, 3])
    """
    # "a" has no default, so it is the one required __init__ argument.
    a: float
    b: NotADataclass = field(default_factory=NotADataclass)
    c: object = field(default=0)
    d: list = dataclasses.field(default_factory=list)

# Subclass of a cdef dataclass, decorated via the `dataclasses.dataclass`
# spelling. It inherits fields a..d, adds "e", and defines __post_init__,
# which the generated __init__ calls.
@dataclasses.dataclass
cdef class InheritsFromDataclass(BasicDataclass):
    """
    >>> sorted(list(InheritsFromDataclass.__dataclass_fields__.keys()))
    ['a', 'b', 'c', 'd', 'e']
    >>> InheritsFromDataclass(a=1.0, e=5)
    In __post_init__
    InheritsFromDataclass(a=1.0, b=NADC, c=0, d=[], e=5)
    """
    e: cython.int = 0

    def __post_init__(self):
        # Call form prints the same text under language_level 2 and 3; the old
        # `print "..."` statement is a syntax error under language_level 3.
        print("In __post_init__")

# Dataclass whose base is a plain cdef class. Per the doctest, only the
# subclass's own annotated field "c" becomes a dataclass field.
@cython.dataclasses.dataclass
cdef class InheritsFromNotADataclass(NotADataclass):
    """
    >>> sorted(list(InheritsFromNotADataclass.__dataclass_fields__.keys()))
    ['c']
    >>> InheritsFromNotADataclass()
    InheritsFromNotADataclass(c=1)
    >>> InheritsFromNotADataclass(5)
    InheritsFromNotADataclass(c=5)
    """

    c: cython.int = 1

# C struct and pointer typedef, used as non-Python dataclass field types in
# ContainsNonPyFields.
cdef struct S:
    int a

ctypedef S* S_ptr

cdef S_ptr malloc_a_struct():
    # default_factory for ContainsNonPyFields.mystruct_ptr.
    # The returned memory is not initialised, and the pointer is NULL if malloc
    # fails. The owning instance frees it in its __dealloc__.
    return <S_ptr>malloc(sizeof(S))

# Dataclass with fields that have no direct Python equivalent:
#   - a C struct, passed as and shown as a dict (see the doctest);
#   - a raw struct pointer, excluded from __init__ and repr;
#   - a typed memoryview with a module-level default.
@dataclass
cdef class ContainsNonPyFields:
    """
    >>> ContainsNonPyFields()  # doctest: +ELLIPSIS
    Traceback (most recent call last):
    TypeError: __init__() takes ... 1 positional ...
    >>> ContainsNonPyFields(mystruct={'a': 1 })  # doctest: +ELLIPSIS
    ContainsNonPyFields(mystruct={'a': 1}, memview=<MemoryView of 'array' at ...>)
    >>> ContainsNonPyFields(mystruct={'a': 1 }, memview=create_array((2,2), "c"))  # doctest: +ELLIPSIS
    ContainsNonPyFields(mystruct={'a': 1}, memview=<MemoryView of 'array' at ...>)
    >>> ContainsNonPyFields(mystruct={'a': 1 }, mystruct_ptr=0)
    Traceback (most recent call last):
    TypeError: __init__() got an unexpected keyword argument 'mystruct_ptr'
    """
    mystruct: S = cython.dataclasses.field(compare=False)
    mystruct_ptr: S_ptr = field(init=False, repr=False, default_factory=malloc_a_struct)
    memview: cython.int[:, ::1] = field(default=create_array((3,1), "c"),  # mutable so not great but OK for a test
                                        compare=False)

    def __dealloc__(self):
        # Release the buffer allocated by malloc_a_struct; free(NULL) is a no-op.
        free(self.mystruct_ptr)

# Exercises InitVar and ClassVar pseudo-fields in two ways:
#   - spelled both via the imported names and module-qualified;
#   - declared both through annotations and through "cdef" declarations.
@dataclass
cdef class InitClassVars:
    """
    Private (i.e. defined with "cdef") members deliberately don't appear
    TODO - ideally c1 and c2 should also be listed here
    >>> sorted(list(InitClassVars.__dataclass_fields__.keys()))
    ['a', 'b1', 'b2']
    >>> InitClassVars.c1
    2.0
    >>> InitClassVars.e1
    []
    >>> inst1 = InitClassVars()
    In __post_init__
    >>> inst1  # init vars don't appear in string
    InitClassVars(a=0)
    >>> inst2 = InitClassVars(b1=5, d2=100)
    In __post_init__
    >>> inst1 == inst2  # comparison ignores the initvar
    True
    """
    a: cython.int = 0
    b1: InitVar[cython.double] = 1.0
    b2: py_dataclasses.InitVar[cython.double] = 1.0
    c1: ClassVar[float] = 2.0
    c2: typing.ClassVar[float] = 2.0
    # Same pseudo-field kinds declared with "cdef"; values are assigned separately.
    cdef InitVar[cython.int] d1
    cdef py_dataclasses.InitVar[cython.int] d2
    d1 = 5
    d2 = 5
    cdef ClassVar[list] e1
    cdef typing.ClassVar[list] e2
    e1 = []
    e2 = []

    def __post_init__(self, b1, b2, d1, d2):
        # Check that the initvars haven't been assigned yet
        assert self.b1==0, self.b1
        assert self.b2==0, self.b2
        assert self.d1==0, self.d1
        assert self.d2==0, self.d2
        self.b1 = b1
        self.b2 = b2
        self.d1 = d1
        self.d2 = d2
        # Call form prints the same text under language_level 2 and 3.
        print("In __post_init__")

# Checks which class-body declarations become dataclass fields. Per the
# doctest, a private "cdef" attribute does not; annotated and "cdef public"
# attributes do.
@dataclass
cdef class TestVisibility:
    """
    >>> inst = TestVisibility()
    >>> "a" in TestVisibility.__dataclass_fields__
    False
    >>> hasattr(inst, "a")
    False
    >>> "b" in TestVisibility.__dataclass_fields__
    True
    >>> hasattr(inst, "b")
    True
    >>> "c" in TestVisibility.__dataclass_fields__
    True
    >>> TestVisibility.__dataclass_fields__["c"].type
    'double'
    >>> hasattr(inst, "c")
    True
    >>> "d" in TestVisibility.__dataclass_fields__
    True
    >>> TestVisibility.__dataclass_fields__["d"].type
    'object'
    >>> hasattr(inst, "d")
    True
    """
    cdef double a
    a = 1.0
    b: cython.double = 2.0
    cdef public double c
    c = 3.0
    cdef public object d
    d = object()

# frozen=True: per the doctest, assigning to a field after construction
# raises AttributeError.
@dataclass(frozen=True)
cdef class TestFrozen:
    """
    >>> inst = TestFrozen(a=5)
    >>> inst.a
    5.0
    >>> inst.a = 2.  # doctest: +ELLIPSIS
    Traceback (most recent call last):
    AttributeError: attribute 'a' of '...TestFrozen' objects is not writable
    """
    a: cython.double = 2.0

def get_dataclass_initvar():
    """Return ``InitVar`` as seen through this module's ``py_dataclasses`` import.

    The module doctest compares the result with ``dataclasses.InitVar``.
    """
    return py_dataclasses.InitVar


# kw_only=True: every field is keyword-only. That is why "b", which has no
# default, may follow "a", which has one.
@dataclass(kw_only=True)
cdef class TestKwOnly:
    """
    >>> inst = TestKwOnly(a=3, b=2)
    >>> inst.a
    3.0
    >>> inst.b
    2
    >>> inst = TestKwOnly(b=2)
    >>> inst.a
    2.0
    >>> inst.b
    2
    >>> fail = TestKwOnly(3, 2)
    Traceback (most recent call last):
    TypeError: __init__() takes exactly 0 positional arguments (2 given)
    >>> fail = TestKwOnly(a=3)
    Traceback (most recent call last):
    TypeError: __init__() needs keyword-only argument b
    >>> fail = TestKwOnly()
    Traceback (most recent call last):
    TypeError: __init__() needs keyword-only argument b
    """

    a: cython.double = 2.0
    b: cython.long


# Module-level doctests that check the Cython dataclasses against the standard
# library's dataclasses helpers (Field, is_dataclass, fields, InitVar).
__doc__ = """
>>> from dataclasses import Field, is_dataclass, fields, InitVar

# It uses the types from the standard library where available
>>> all(isinstance(v, Field) for v in BasicDataclass.__dataclass_fields__.values())
True

# check out Cython dataclasses are close enough to convince it
>>> is_dataclass(BasicDataclass)
True
>>> is_dataclass(BasicDataclass(1.5))
True
>>> is_dataclass(InheritsFromDataclass)
True
>>> is_dataclass(NotADataclass)
False
>>> is_dataclass(InheritsFromNotADataclass)
True
>>> [ f.name for f in fields(BasicDataclass)]
['a', 'b', 'c', 'd']
>>> [ f.name for f in fields(InitClassVars)]
['a']
>>> get_dataclass_initvar() == InitVar
True
"""

# Hosting-page footer captured with this file (cookie notice, translated from Russian).
# "Use of cookies. We use cookie files in accordance with the Privacy Policy and the
# Cookie Policy. By clicking "Accept", you give JSC SberTech consent to process your
# personal data in order to improve our website and the GitVerse service and to make
# them more convenient to use. You can disable cookies yourself in your browser settings."