#!/usr/bin/env python3
# llama — json_schema_to_grammar.py (forks: 0; original listing: 811 lines, 32.9 KB)
from __future__ import annotations
3

4
import argparse
5
import itertools
6
import json
7
import re
8
import sys
9
from typing import Any, List, Optional, Set, Tuple, Union
10

11
def _build_repetition(item_rule, min_items, max_items, separator_rule=None):
12

13
    if min_items == 0 and max_items == 1:
14
        return f'{item_rule}?'
15

16
    if not separator_rule:
17
        if min_items == 1 and max_items is None:
18
            return f'{item_rule}+'
19
        elif min_items == 0 and max_items is None:
20
            return f'{item_rule}*'
21
        else:
22
            return f'{item_rule}{{{min_items},{max_items if max_items is not None else ""}}}'
23

24
    result = item_rule + ' ' + _build_repetition(f'({separator_rule} {item_rule})', min_items - 1 if min_items > 0 else 0, max_items - 1 if max_items is not None else None)
25
    return f'({result})?' if min_items == 0 else result
26

27
def _generate_min_max_int(min_value: Optional[int], max_value: Optional[int], out: list, decimals_left: int = 16, top_level: bool = True):
28
    has_min = min_value != None
29
    has_max = max_value != None
30

31
    def digit_range(from_char: str, to_char: str):
32
        out.append("[")
33
        if from_char == to_char:
34
            out.append(from_char)
35
        else:
36
            out.append(from_char)
37
            out.append("-")
38
            out.append(to_char)
39
        out.append("]")
40

41
    def more_digits(min_digits: int, max_digits: int):
42
        out.append("[0-9]")
43
        if min_digits == max_digits and min_digits == 1:
44
            return
45
        out.append("{")
46
        out.append(str(min_digits))
47
        if max_digits != min_digits:
48
            out.append(",")
49
            if max_digits != sys.maxsize:
50
                out.append(str(max_digits))
51
        out.append("}")
52

53
    def uniform_range(from_str: str, to_str: str):
54
        i = 0
55
        while i < len(from_str) and from_str[i] == to_str[i]:
56
            i += 1
57
        if i > 0:
58
            out.append("\"")
59
            out.append(from_str[:i])
60
            out.append("\"")
61
        if i < len(from_str):
62
            if i > 0:
63
                out.append(" ")
64
            sub_len = len(from_str) - i - 1
65
            if sub_len > 0:
66
                from_sub = from_str[i+1:]
67
                to_sub = to_str[i+1:]
68
                sub_zeros = "0" * sub_len
69
                sub_nines = "9" * sub_len
70

71
                to_reached = False
72
                out.append("(")
73
                if from_sub == sub_zeros:
74
                    digit_range(from_str[i], chr(ord(to_str[i]) - 1))
75
                    out.append(" ")
76
                    more_digits(sub_len, sub_len)
77
                else:
78
                    out.append("[")
79
                    out.append(from_str[i])
80
                    out.append("] ")
81
                    out.append("(")
82
                    uniform_range(from_sub, sub_nines)
83
                    out.append(")")
84
                    if ord(from_str[i]) < ord(to_str[i]) - 1:
85
                        out.append(" | ")
86
                        if to_sub == sub_nines:
87
                            digit_range(chr(ord(from_str[i]) + 1), to_str[i])
88
                            to_reached = True
89
                        else:
90
                            digit_range(chr(ord(from_str[i]) + 1), chr(ord(to_str[i]) - 1))
91
                        out.append(" ")
92
                        more_digits(sub_len, sub_len)
93
                if not to_reached:
94
                    out.append(" | ")
95
                    digit_range(to_str[i], to_str[i])
96
                    out.append(" ")
97
                    uniform_range(sub_zeros, to_sub)
98
                out.append(")")
99
            else:
100
                out.append("[")
101
                out.append(from_str[i])
102
                out.append("-")
103
                out.append(to_str[i])
104
                out.append("]")
105

106
    if has_min and has_max:
107
        if min_value < 0 and max_value < 0:
108
            out.append("\"-\" (")
109
            _generate_min_max_int(-max_value, -min_value, out, decimals_left, top_level=True)
110
            out.append(")")
111
            return
112

113
        if min_value < 0:
114
            out.append("\"-\" (")
115
            _generate_min_max_int(0, -min_value, out, decimals_left, top_level=True)
116
            out.append(") | ")
117
            min_value = 0
118

119
        min_s = str(min_value)
120
        max_s = str(max_value)
121
        min_digits = len(min_s)
122
        max_digits = len(max_s)
123

124
        for digits in range(min_digits, max_digits):
125
            uniform_range(min_s, "9" * digits)
126
            min_s = "1" + "0" * digits
127
            out.append(" | ")
128
        uniform_range(min_s, max_s)
129
        return
130

131
    less_decimals = max(decimals_left - 1, 1)
132

133
    if has_min:
134
        if min_value < 0:
135
            out.append("\"-\" (")
136
            _generate_min_max_int(None, -min_value, out, decimals_left, top_level=False)
137
            out.append(") | [0] | [1-9] ")
138
            more_digits(0, decimals_left - 1)
139
        elif min_value == 0:
140
            if top_level:
141
                out.append("[0] | [1-9] ")
142
                more_digits(0, less_decimals)
143
            else:
144
                more_digits(1, decimals_left)
145
        elif min_value <= 9:
146
            c = str(min_value)
147
            range_start = '1' if top_level else '0'
148
            if c > range_start:
149
                digit_range(range_start, chr(ord(c) - 1))
150
                out.append(" ")
151
                more_digits(1, less_decimals)
152
                out.append(" | ")
153
            digit_range(c, "9")
154
            out.append(" ")
155
            more_digits(0, less_decimals)
156
        else:
157
            min_s = str(min_value)
158
            length = len(min_s)
159
            c = min_s[0]
160

161
            if c > "1":
162
                digit_range("1" if top_level else "0", chr(ord(c) - 1))
163
                out.append(" ")
164
                more_digits(length, less_decimals)
165
                out.append(" | ")
166
            digit_range(c, c)
167
            out.append(" (")
168
            _generate_min_max_int(int(min_s[1:]), None, out, less_decimals, top_level=False)
169
            out.append(")")
170
            if c < "9":
171
                out.append(" | ")
172
                digit_range(chr(ord(c) + 1), "9")
173
                out.append(" ")
174
                more_digits(length - 1, less_decimals)
175
        return
176

177
    if has_max:
178
        if max_value >= 0:
179
            if top_level:
180
                out.append("\"-\" [1-9] ")
181
                more_digits(0, less_decimals)
182
                out.append(" | ")
183
            _generate_min_max_int(0, max_value, out, decimals_left, top_level=True)
184
        else:
185
            out.append("\"-\" (")
186
            _generate_min_max_int(-max_value, None, out, decimals_left, top_level=False)
187
            out.append(")")
188
        return
189

190
    raise RuntimeError("At least one of min_value or max_value must be set")
191

192
class BuiltinRule:
193
    def __init__(self, content: str, deps: list | None = None):
194
        self.content = content
195
        self.deps = deps or []
196

197
# Constraining spaces to prevent model "running away".
198
SPACE_RULE = '| " " | "\\n" [ \\t]{0,20}'
199

200
PRIMITIVE_RULES = {
201
    'boolean'      : BuiltinRule('("true" | "false") space', []),
202
    'decimal-part' : BuiltinRule('[0-9]{1,16}', []),
203
    'integral-part': BuiltinRule('[0] | [1-9] [0-9]{0,15}', []),
204
    'number'       : BuiltinRule('("-"? integral-part) ("." decimal-part)? ([eE] [-+]? integral-part)? space', ['integral-part', 'decimal-part']),
205
    'integer'      : BuiltinRule('("-"? integral-part) space', ['integral-part']),
206
    'value'        : BuiltinRule('object | array | string | number | boolean | null', ['object', 'array', 'string', 'number', 'boolean', 'null']),
207
    'object'       : BuiltinRule('"{" space ( string ":" space value ("," space string ":" space value)* )? "}" space', ['string', 'value']),
208
    'array'        : BuiltinRule('"[" space ( value ("," space value)* )? "]" space', ['value']),
209
    'uuid'         : BuiltinRule(r'"\"" [0-9a-fA-F]{8} "-" [0-9a-fA-F]{4} "-" [0-9a-fA-F]{4} "-" [0-9a-fA-F]{4} "-" [0-9a-fA-F]{12} "\"" space', []),
210
    'char'         : BuiltinRule(r'[^"\\\x7F\x00-\x1F] | [\\] (["\\bfnrt] | "u" [0-9a-fA-F]{4})', []),
211
    'string'       : BuiltinRule(r'"\"" char* "\"" space', ['char']),
212
    'null'         : BuiltinRule('"null" space', []),
213
}
214

215
# TODO: support "uri", "email" string formats
216
STRING_FORMAT_RULES = {
217
    'date'            : BuiltinRule('[0-9]{4} "-" ( "0" [1-9] | "1" [0-2] ) "-" ( \"0\" [1-9] | [1-2] [0-9] | "3" [0-1] )', []),
218
    'time'            : BuiltinRule('([01] [0-9] | "2" [0-3]) ":" [0-5] [0-9] ":" [0-5] [0-9] ( "." [0-9]{3} )? ( "Z" | ( "+" | "-" ) ( [01] [0-9] | "2" [0-3] ) ":" [0-5] [0-9] )', []),
219
    'date-time'       : BuiltinRule('date "T" time', ['date', 'time']),
220
    'date-string'     : BuiltinRule('"\\"" date "\\"" space', ['date']),
221
    'time-string'     : BuiltinRule('"\\"" time "\\"" space', ['time']),
222
    'date-time-string': BuiltinRule('"\\"" date-time "\\"" space', ['date-time']),
223
}
224

225
DOTALL = '[\\U00000000-\\U0010FFFF]'
226
DOT = '[^\\x0A\\x0D]'
227

228
RESERVED_NAMES = set(["root", "dot", *PRIMITIVE_RULES.keys(), *STRING_FORMAT_RULES.keys()])
229

230
INVALID_RULE_CHARS_RE = re.compile(r'[^a-zA-Z0-9-]+')
231
GRAMMAR_LITERAL_ESCAPE_RE = re.compile(r'[\r\n"]')
232
GRAMMAR_RANGE_LITERAL_ESCAPE_RE = re.compile(r'[\r\n"\]\-\\]')
233
GRAMMAR_LITERAL_ESCAPES = {'\r': '\\r', '\n': '\\n', '"': '\\"', '-': '\\-', ']': '\\]'}
234

235
NON_LITERAL_SET = set('|.()[]{}*+?')
236
ESCAPED_IN_REGEXPS_BUT_NOT_IN_LITERALS = set('^$.[]()|{}*+?')
237

238

239
class SchemaConverter:
240
    def __init__(self, *, prop_order, allow_fetch, dotall, raw_pattern):
241
        self._prop_order = prop_order
242
        self._allow_fetch = allow_fetch
243
        self._dotall = dotall
244
        self._raw_pattern = raw_pattern
245
        self._rules = {
246
            'space': SPACE_RULE,
247
        }
248
        self._refs = {}
249
        self._refs_being_resolved = set()
250

251
    def _format_literal(self, literal):
252
        escaped = GRAMMAR_LITERAL_ESCAPE_RE.sub(
253
            lambda m: GRAMMAR_LITERAL_ESCAPES.get(m.group(0)) or m.group(0), literal
254
        )
255
        return f'"{escaped}"'
256

257
    def not_literal(self, literal: str, dotall: bool = True, maybe_escaped_underscores = False) -> str:
258
        '''
259
            not_literal('a') -> '[^a]'
260
            not_literal('abc') -> '([^a] | "a" ([^b] | "b" ([^c])?)?)?'
261
        '''
262
        assert len(literal) > 0, 'Empty literal not supported'
263
        def recurse(i: int):
264
            c = literal[i]
265
            if maybe_escaped_underscores and c == '_':
266
                yield f'[^{c}\\\\]'
267
                yield ' | '
268
                yield f'"\\\\"? "{c}"'
269
            else:
270
                yield f'[^{c}]'
271
            if i < len(literal) - 1:
272
                yield ' | '
273
                yield self._format_literal(c)
274
                yield ' ('
275
                yield from recurse(i + 1)
276
                yield ')?'
277

278
        return ''.join(('(', *recurse(0), ')'))
279

280
    def _not_strings(self, strings):
281
        class TrieNode:
282
            def __init__(self):
283
                self.children = {}
284
                self.is_end_of_string = False
285

286
            def insert(self, string):
287
                node = self
288
                for c in string:
289
                    node = node.children.setdefault(c, TrieNode())
290
                node.is_end_of_string = True
291

292
        trie = TrieNode()
293
        for s in strings:
294
            trie.insert(s)
295

296
        char_rule = self._add_primitive('char', PRIMITIVE_RULES['char'])
297
        out = ['["] ( ']
298

299
        def visit(node):
300
            rejects = []
301
            first = True
302
            for c in sorted(node.children.keys()):
303
                child = node.children[c]
304
                rejects.append(c)
305
                if first:
306
                    first = False
307
                else:
308
                    out.append(' | ')
309
                out.append(f'[{c}]')
310
                if child.children:
311
                    out.append(f' (')
312
                    visit(child)
313
                    out.append(')')
314
                elif child.is_end_of_string:
315
                    out.append(f' {char_rule}+')
316
            if node.children:
317
                if not first:
318
                    out.append(' | ')
319
                out.append(f'[^"{"".join(rejects)}] {char_rule}*')
320
        visit(trie)
321

322
        out.append(f' ){"" if trie.is_end_of_string else "?"} ["] space')
323
        return ''.join(out)
324

325
    def _add_rule(self, name, rule):
326
        esc_name = INVALID_RULE_CHARS_RE.sub('-', name)
327
        if esc_name not in self._rules or self._rules[esc_name] == rule:
328
            key = esc_name
329
        else:
330
            i = 0
331
            while f'{esc_name}{i}' in self._rules and self._rules[f'{esc_name}{i}'] != rule:
332
                i += 1
333
            key = f'{esc_name}{i}'
334
        self._rules[key] = rule
335
        return key
336

337
    def resolve_refs(self, schema: dict, url: str):
338
        '''
339
            Resolves all $ref fields in the given schema, fetching any remote schemas,
340
            replacing $ref with absolute reference URL and populating self._refs with the
341
            respective referenced (sub)schema dictionaries.
342
        '''
343
        def visit(n: dict):
344
            if isinstance(n, list):
345
                return [visit(x) for x in n]
346
            elif isinstance(n, dict):
347
                ref = n.get('$ref')
348
                if ref is not None and ref not in self._refs:
349
                    if ref.startswith('https://'):
350
                        assert self._allow_fetch, 'Fetching remote schemas is not allowed (use --allow-fetch for force)'
351
                        import requests
352

353
                        frag_split = ref.split('#')
354
                        base_url = frag_split[0]
355

356
                        target = self._refs.get(base_url)
357
                        if target is None:
358
                            target = self.resolve_refs(requests.get(ref).json(), base_url)
359
                            self._refs[base_url] = target
360

361
                        if len(frag_split) == 1 or frag_split[-1] == '':
362
                            return target
363
                    elif ref.startswith('#/'):
364
                        target = schema
365
                        ref = f'{url}{ref}'
366
                        n['$ref'] = ref
367
                    else:
368
                        raise ValueError(f'Unsupported ref {ref}')
369

370
                    for sel in ref.split('#')[-1].split('/')[1:]:
371
                        assert target is not None and sel in target, f'Error resolving ref {ref}: {sel} not in {target}'
372
                        target = target[sel]
373

374
                    self._refs[ref] = target
375
                else:
376
                    for v in n.values():
377
                        visit(v)
378

379
            return n
380
        return visit(schema)
381

382
    def _generate_union_rule(self, name, alt_schemas):
383
        return ' | '.join((
384
            self.visit(alt_schema, f'{name}{"-" if name else "alternative-"}{i}')
385
            for i, alt_schema in enumerate(alt_schemas)
386
        ))
387

388
    def _visit_pattern(self, pattern, name):
389
        '''
390
            Transforms a regular expression pattern into a GBNF rule.
391

392
            Input: https://json-schema.org/understanding-json-schema/reference/regular_expressions
393
            Output: https://github.com/ggerganov/llama.cpp/blob/master/grammars/README.md
394

395
            Unsupported features: negative/positive lookaheads, greedy/non-greedy modifiers.
396

397
            Mostly a 1:1 translation, except for {x} / {x,} / {x,y} quantifiers for which
398
            we define sub-rules to keep the output lean.
399
        '''
400

401
        assert pattern.startswith('^') and pattern.endswith('$'), 'Pattern must start with "^" and end with "$"'
402
        pattern = pattern[1:-1]
403
        sub_rule_ids = {}
404

405
        i = 0
406
        length = len(pattern)
407

408
        def to_rule(s: tuple[str, bool]) -> str:
409
            (txt, is_literal) = s
410
            return "\"" + txt + "\"" if is_literal else txt
411

412
        def transform() -> tuple[str, bool]:
413
            '''
414
                Parse a unit at index i (advancing it), and return its string representation + whether it's a literal.
415
            '''
416
            nonlocal i
417
            nonlocal pattern
418
            nonlocal sub_rule_ids
419

420
            start = i
421
            # For each component of this sequence, store its string representation and whether it's a literal.
422
            # We only need a flat structure here to apply repetition operators to the last item, and
423
            # to merge literals at the and (we're parsing grouped ( sequences ) recursively and don't treat '|' specially
424
            # (GBNF's syntax is luckily very close to regular expressions!)
425
            seq: list[tuple[str, bool]] = []
426

427
            def get_dot():
428
                if self._dotall:
429
                    rule = DOTALL
430
                else:
431
                    # Accept any character... except \n and \r line break chars (\x0A and \xOD)
432
                    rule = DOT
433
                return self._add_rule(f'dot', rule)
434

435
            def join_seq():
436
                nonlocal seq
437
                ret = []
438
                for is_literal, g in itertools.groupby(seq, lambda x: x[1]):
439
                    if is_literal:
440
                        ret.append((''.join(x[0] for x in g), True))
441
                    else:
442
                        ret.extend(g)
443
                if len(ret) == 1:
444
                    return ret[0]
445
                return (' '.join(to_rule(x) for x in seq), False)
446

447
            while i < length:
448
                c = pattern[i]
449
                if c == '.':
450
                    seq.append((get_dot(), False))
451
                    i += 1
452
                elif c == '(':
453
                    i += 1
454
                    if i < length:
455
                        assert pattern[i] != '?', f'Unsupported pattern syntax "{pattern[i]}" at index {i} of /{pattern}/'
456
                    seq.append((f'({to_rule(transform())})', False))
457
                elif c == ')':
458
                    i += 1
459
                    assert start > 0 and pattern[start-1] == '(', f'Unbalanced parentheses; start = {start}, i = {i}, pattern = {pattern}'
460
                    return join_seq()
461
                elif c == '[':
462
                    square_brackets = c
463
                    i += 1
464
                    while i < length and pattern[i] != ']':
465
                        if pattern[i] == '\\':
466
                            square_brackets += pattern[i:i+2]
467
                            i += 2
468
                        else:
469
                            square_brackets += pattern[i]
470
                            i += 1
471
                    assert i < length, f'Unbalanced square brackets; start = {start}, i = {i}, pattern = {pattern}'
472
                    square_brackets += ']'
473
                    i += 1
474
                    seq.append((square_brackets, False))
475
                elif c == '|':
476
                    seq.append(('|', False))
477
                    i += 1
478
                elif c in ('*', '+', '?'):
479
                    seq[-1] = (to_rule(seq[-1]) + c, False)
480
                    i += 1
481
                elif c == '{':
482
                    curly_brackets = c
483
                    i += 1
484
                    while i < length and pattern[i] != '}':
485
                        curly_brackets += pattern[i]
486
                        i += 1
487
                    assert i < length, f'Unbalanced curly brackets; start = {start}, i = {i}, pattern = {pattern}'
488
                    curly_brackets += '}'
489
                    i += 1
490
                    nums = [s.strip() for s in curly_brackets[1:-1].split(',')]
491
                    min_times = 0
492
                    max_times = None
493
                    try:
494
                        if len(nums) == 1:
495
                            min_times = int(nums[0])
496
                            max_times = min_times
497
                        else:
498
                            assert len(nums) == 2
499
                            min_times = int(nums[0]) if nums[0] else 0
500
                            max_times = int(nums[1]) if nums[1] else None
501
                    except ValueError:
502
                        raise ValueError(f'Invalid quantifier {curly_brackets} in /{pattern}/')
503

504
                    (sub, sub_is_literal) = seq[-1]
505

506
                    if not sub_is_literal:
507
                        id = sub_rule_ids.get(sub)
508
                        if id is None:
509
                            id = self._add_rule(f'{name}-{len(sub_rule_ids) + 1}', sub)
510
                            sub_rule_ids[sub] = id
511
                        sub = id
512

513
                    seq[-1] = (_build_repetition(f'"{sub}"' if sub_is_literal else sub, min_times, max_times), False)
514
                else:
515
                    literal = ''
516
                    while i < length:
517
                        if pattern[i] == '\\' and i < length - 1:
518
                            next = pattern[i + 1]
519
                            if next in ESCAPED_IN_REGEXPS_BUT_NOT_IN_LITERALS:
520
                                i += 1
521
                                literal += pattern[i]
522
                                i += 1
523
                            else:
524
                                literal += pattern[i:i+2]
525
                                i += 2
526
                        elif pattern[i] == '"' and not self._raw_pattern:
527
                            literal += '\\"'
528
                            i += 1
529
                        elif pattern[i] not in NON_LITERAL_SET and \
530
                                (i == length - 1 or literal == '' or pattern[i+1] == '.' or pattern[i+1] not in NON_LITERAL_SET):
531
                            literal += pattern[i]
532
                            i += 1
533
                        else:
534
                            break
535
                    if literal:
536
                        seq.append((literal, True))
537

538
            return join_seq()
539

540
        return self._add_rule(
541
            name,
542
            to_rule(transform()) if self._raw_pattern \
543
                else "\"\\\"\" " + to_rule(transform()) + " \"\\\"\" space")
544

545

546
    def _resolve_ref(self, ref):
547
        ref_name = ref.split('/')[-1]
548
        if ref_name not in self._rules and ref not in self._refs_being_resolved:
549
            self._refs_being_resolved.add(ref)
550
            resolved = self._refs[ref]
551
            ref_name = self.visit(resolved, ref_name)
552
            self._refs_being_resolved.remove(ref)
553
        return ref_name
554

555
    def _generate_constant_rule(self, value):
556
        return self._format_literal(json.dumps(value))
557

558
    def visit(self, schema, name):
559
        schema_type = schema.get('type')
560
        schema_format = schema.get('format')
561
        rule_name = name + '-' if name in RESERVED_NAMES else name or 'root'
562

563
        if (ref := schema.get('$ref')) is not None:
564
            return self._add_rule(rule_name, self._resolve_ref(ref))
565

566
        elif 'oneOf' in schema or 'anyOf' in schema:
567
            return self._add_rule(rule_name, self._generate_union_rule(name, schema.get('oneOf') or schema['anyOf']))
568

569
        elif isinstance(schema_type, list):
570
            return self._add_rule(rule_name, self._generate_union_rule(name, [{**schema, 'type': t} for t in schema_type]))
571

572
        elif 'const' in schema:
573
            return self._add_rule(rule_name, self._generate_constant_rule(schema['const']) + ' space')
574

575
        elif 'enum' in schema:
576
            rule = '(' + ' | '.join((self._generate_constant_rule(v) for v in schema['enum'])) + ') space'
577
            return self._add_rule(rule_name, rule)
578

579
        elif schema_type in (None, 'object') and \
580
             ('properties' in schema or \
581
              ('additionalProperties' in schema and schema['additionalProperties'] is not True)):
582
            required = set(schema.get('required', []))
583
            properties = list(schema.get('properties', {}).items())
584
            return self._add_rule(rule_name, self._build_object_rule(properties, required, name, schema.get('additionalProperties')))
585

586
        elif schema_type in (None, 'object') and 'allOf' in schema:
587
            required = set()
588
            properties = []
589
            hybrid_name = name
590
            def add_component(comp_schema, is_required):
591
                if (ref := comp_schema.get('$ref')) is not None:
592
                    comp_schema = self._refs[ref]
593

594
                if 'properties' in comp_schema:
595
                    for prop_name, prop_schema in comp_schema['properties'].items():
596
                        properties.append((prop_name, prop_schema))
597
                        if is_required:
598
                            required.add(prop_name)
599

600
            for t in schema['allOf']:
601
                if 'anyOf' in t:
602
                    for tt in t['anyOf']:
603
                        add_component(tt, is_required=False)
604
                else:
605
                    add_component(t, is_required=True)
606

607
            return self._add_rule(rule_name, self._build_object_rule(properties, required, hybrid_name, additional_properties=None))
608

609
        elif schema_type in (None, 'array') and ('items' in schema or 'prefixItems' in schema):
610
            items = schema.get('items') or schema['prefixItems']
611
            if isinstance(items, list):
612
                return self._add_rule(
613
                    rule_name,
614
                    '"[" space ' +
615
                    ' "," space '.join(
616
                        self.visit(item, f'{name}{"-" if name else ""}tuple-{i}')
617
                        for i, item in enumerate(items)) +
618
                    ' "]" space')
619
            else:
620
                item_rule_name = self.visit(items, f'{name}{"-" if name else ""}item')
621
                min_items = schema.get("minItems", 0)
622
                max_items = schema.get("maxItems")
623
                return self._add_rule(rule_name, '"[" space ' + _build_repetition(item_rule_name, min_items, max_items, separator_rule='"," space') + ' "]" space')
624

625
        elif schema_type in (None, 'string') and 'pattern' in schema:
626
            return self._visit_pattern(schema['pattern'], rule_name)
627

628
        elif schema_type in (None, 'string') and re.match(r'^uuid[1-5]?$', schema_format or ''):
629
            return self._add_primitive(
630
                'root' if rule_name == 'root' else schema_format,
631
                PRIMITIVE_RULES['uuid']
632
            )
633

634
        elif schema_type in (None, 'string') and f'{schema_format}-string' in STRING_FORMAT_RULES:
635
            prim_name = f'{schema_format}-string'
636
            return self._add_rule(rule_name, self._add_primitive(prim_name, STRING_FORMAT_RULES[prim_name]))
637

638
        elif schema_type == 'string' and ('minLength' in schema or 'maxLength' in schema):
639
            char_rule = self._add_primitive('char', PRIMITIVE_RULES['char'])
640
            min_len = schema.get('minLength', 0)
641
            max_len = schema.get('maxLength')
642

643
            return self._add_rule(rule_name, r'"\"" ' + _build_repetition(char_rule, min_len, max_len) + r' "\"" space')
644

645
        elif schema_type in (None, 'integer') and \
646
                ('minimum' in schema or 'exclusiveMinimum' in schema or 'maximum' in schema or 'exclusiveMaximum' in schema):
647
            min_value = None
648
            max_value = None
649
            if 'minimum' in schema:
650
                min_value = schema['minimum']
651
            elif 'exclusiveMinimum' in schema:
652
                min_value = schema['exclusiveMinimum'] + 1
653
            if 'maximum' in schema:
654
                max_value = schema['maximum']
655
            elif 'exclusiveMaximum' in schema:
656
                max_value = schema['exclusiveMaximum'] - 1
657

658
            out = ["("]
659
            _generate_min_max_int(min_value, max_value, out)
660
            out.append(") space")
661
            return self._add_rule(rule_name, ''.join(out))
662

663
        elif (schema_type == 'object') or (len(schema) == 0):
664
            return self._add_rule(rule_name, self._add_primitive('object', PRIMITIVE_RULES['object']))
665

666
        else:
667
            assert schema_type in PRIMITIVE_RULES, f'Unrecognized schema: {schema}'
668
            # TODO: support minimum, maximum, exclusiveMinimum, exclusiveMaximum at least for zero
669
            return self._add_primitive('root' if rule_name == 'root' else schema_type, PRIMITIVE_RULES[schema_type])
670

671
    def _add_primitive(self, name: str, rule: BuiltinRule):
672
        n = self._add_rule(name, rule.content)
673

674
        for dep in rule.deps:
675
            dep_rule = PRIMITIVE_RULES.get(dep) or STRING_FORMAT_RULES.get(dep)
676
            assert dep_rule, f'Rule {dep} not known'
677
            if dep not in self._rules:
678
                self._add_primitive(dep, dep_rule)
679
        return n
680

681
    def _build_object_rule(self, properties: List[Tuple[str, Any]], required: Set[str], name: str, additional_properties: Optional[Union[bool, Any]]):
682
        prop_order = self._prop_order
683
        # sort by position in prop_order (if specified) then by original order
684
        sorted_props = [kv[0] for _, kv in sorted(enumerate(properties), key=lambda ikv: (prop_order.get(ikv[1][0], len(prop_order)), ikv[0]))]
685

686
        prop_kv_rule_names = {}
687
        for prop_name, prop_schema in properties:
688
            prop_rule_name = self.visit(prop_schema, f'{name}{"-" if name else ""}{prop_name}')
689
            prop_kv_rule_names[prop_name] = self._add_rule(
690
                f'{name}{"-" if name else ""}{prop_name}-kv',
691
                fr'{self._format_literal(json.dumps(prop_name))} space ":" space {prop_rule_name}'
692
            )
693
        required_props = [k for k in sorted_props if k in required]
694
        optional_props = [k for k in sorted_props if k not in required]
695

696
        if additional_properties is not None and additional_properties != False:
697
            sub_name = f'{name}{"-" if name else ""}additional'
698
            value_rule = self.visit(additional_properties, f'{sub_name}-value') if isinstance(additional_properties, dict) else \
699
                self._add_primitive('value', PRIMITIVE_RULES['value'])
700
            key_rule = self._add_primitive('string', PRIMITIVE_RULES['string']) if not sorted_props \
701
                else self._add_rule(f'{sub_name}-k', self._not_strings(sorted_props))
702

703
            prop_kv_rule_names["*"] = self._add_rule(
704
                f'{sub_name}-kv',
705
                f'{key_rule} ":" space {value_rule}'
706
            )
707
            optional_props.append("*")
708

709
        rule = '"{" space '
710
        rule += ' "," space '.join(prop_kv_rule_names[k] for k in required_props)
711

712
        if optional_props:
713
            rule += ' ('
714
            if required_props:
715
                rule += ' "," space ( '
716

717
            def get_recursive_refs(ks, first_is_optional):
718
                [k, *rest] = ks
719
                kv_rule_name = prop_kv_rule_names[k]
720
                comma_ref = f'( "," space {kv_rule_name} )'
721
                if first_is_optional:
722
                    res = comma_ref + ('*' if k == '*' else '?')
723
                else:
724
                    res = kv_rule_name + (' ' + comma_ref + "*" if k == '*' else '')
725
                if len(rest) > 0:
726
                    res += ' ' + self._add_rule(
727
                        f'{name}{"-" if name else ""}{k}-rest',
728
                        get_recursive_refs(rest, first_is_optional=True)
729
                    )
730
                return res
731

732
            rule += ' | '.join(
733
                get_recursive_refs(optional_props[i:], first_is_optional=False)
734
                for i in range(len(optional_props))
735
            )
736
            if required_props:
737
                rule += ' )'
738
            rule += ' )?'
739

740
        rule += ' "}" space'
741

742
        return rule
743

744
    def format_grammar(self):
745
        return '\n'.join(
746
            f'{name} ::= {rule}'
747
            for name, rule in sorted(self._rules.items(), key=lambda kv: kv[0])
748
        )
749

750

751
def main(args_in=None):
    """Command-line entry point: convert a JSON schema into a grammar and print it.

    Parses ``args_in`` with argparse. Loads the schema from an ``https://``
    URL, from stdin (when the argument is ``-``) or from a local file. Feeds
    the schema through ``SchemaConverter`` and prints the formatted grammar to
    stdout.

    Args:
        args_in: Optional list of command-line arguments. ``None`` lets
            argparse read ``sys.argv[1:]``.

    Raises:
        SystemExit: On invalid command-line arguments (raised by argparse).
        requests.HTTPError: If fetching a remote schema returns an error status.
        requests.Timeout: If the remote server does not respond in time.
        OSError: If the local schema file cannot be opened.
        ValueError: If the schema is not valid JSON (``json.JSONDecodeError``
            and its requests counterpart both derive from ``ValueError``).
    """
    parser = argparse.ArgumentParser(
        description='''
            Generates a grammar (suitable for use in ./llama-cli) that produces JSON conforming to a
            given JSON schema. Only a subset of JSON schema features are supported; more may be
            added in the future.
        ''',
    )
    parser.add_argument(
        '--prop-order',
        default=[],
        type=lambda s: s.split(','),
        help='''
            comma-separated property names defining the order of precedence for object properties;
            properties not specified here are given lower precedence than those that are, and
            are kept in their original order from the schema. Required properties are always
            given precedence over optional properties.
        '''
    )
    parser.add_argument(
        '--allow-fetch',
        action='store_true',
        default=False,
        help='Whether to allow fetching referenced schemas over HTTPS')
    parser.add_argument(
        '--dotall',
        action='store_true',
        default=False,
        help='Whether to treat dot (".") as matching all chars including line breaks in regular expression patterns')
    parser.add_argument(
        '--raw-pattern',
        action='store_true',
        default=False,
        help='Treats string patterns as raw patterns w/o quotes (or quote escapes)')

    parser.add_argument('schema', help='file containing JSON schema ("-" for stdin)')
    args = parser.parse_args(args_in)

    # `url` records where the schema came from; it is passed to resolve_refs below.
    if args.schema.startswith('https://'):
        url = args.schema
        # Imported lazily so `requests` is only required when fetching remotely.
        import requests
        # Bound the wait so an unresponsive server cannot hang the CLI forever.
        response = requests.get(url, timeout=60)
        # Report HTTP errors (404, 500, ...) directly. Otherwise they would
        # surface as an opaque JSON decode failure on an HTML error page.
        response.raise_for_status()
        schema = response.json()
    elif args.schema == '-':
        url = 'stdin'
        schema = json.load(sys.stdin)
    else:
        url = f'file://{args.schema}'
        # JSON text is UTF-8 (RFC 8259), so don't depend on the platform's locale encoding.
        with open(args.schema, encoding='utf-8') as f:
            schema = json.load(f)
    converter = SchemaConverter(
        prop_order={name: idx for idx, name in enumerate(args.prop_order)},
        allow_fetch=args.allow_fetch,
        dotall=args.dotall,
        raw_pattern=args.raw_pattern)
    schema = converter.resolve_refs(schema, url)
    converter.visit(schema, '')
    print(converter.format_grammar())


if __name__ == '__main__':
    main()
812

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.