2
from __future__ import annotations
9
from typing import Any, List, Optional, Set, Tuple, Union
11
def _build_repetition(item_rule, min_items, max_items, separator_rule=None):
    """Build a grammar expression that repeats ``item_rule``.

    Args:
        item_rule: Grammar expression (or rule name) to repeat.
        min_items: Minimum number of repetitions.
        max_items: Maximum number of repetitions, or ``None`` for no upper bound.
        separator_rule: Optional expression placed between consecutive items.

    Returns:
        The repetition expression as a string. An empty string is returned
        when ``max_items`` is 0, since nothing may be emitted in that case.
    """
    # Zero items allowed: emit nothing. Without this guard the separator branch
    # below would recurse with max_items - 1 == -1 and emit an invalid "{0,-1}".
    if max_items == 0:
        return ''

    if min_items == 0 and max_items == 1:
        return f'{item_rule}?'

    if not separator_rule:
        if min_items == 1 and max_items is None:
            return f'{item_rule}+'
        if min_items == 0 and max_items is None:
            return f'{item_rule}*'
        upper_bound = '' if max_items is None else max_items
        return f'{item_rule}{{{min_items},{upper_bound}}}'

    # With a separator: one leading item, then repeated "(separator item)"
    # groups with both bounds lowered by one (the minimum never goes below 0).
    following = _build_repetition(
        f'({separator_rule} {item_rule})',
        min_items - 1 if min_items > 0 else 0,
        max_items - 1 if max_items is not None else None,
    )
    # `following` is empty when exactly one item is allowed; avoid a dangling space.
    result = f'{item_rule} {following}' if following else item_rule
    return f'({result})?' if min_items == 0 else result
27
def _generate_min_max_int(min_value: Optional[int], max_value: Optional[int], out: list, decimals_left: int = 16, top_level: bool = True):
28
has_min = min_value != None
29
has_max = max_value != None
31
def digit_range(from_char: str, to_char: str):
33
if from_char == to_char:
41
def more_digits(min_digits: int, max_digits: int):
43
if min_digits == max_digits and min_digits == 1:
46
out.append(str(min_digits))
47
if max_digits != min_digits:
49
if max_digits != sys.maxsize:
50
out.append(str(max_digits))
53
def uniform_range(from_str: str, to_str: str):
55
while i < len(from_str) and from_str[i] == to_str[i]:
59
out.append(from_str[:i])
64
sub_len = len(from_str) - i - 1
66
from_sub = from_str[i+1:]
68
sub_zeros = "0" * sub_len
69
sub_nines = "9" * sub_len
73
if from_sub == sub_zeros:
74
digit_range(from_str[i], chr(ord(to_str[i]) - 1))
76
more_digits(sub_len, sub_len)
79
out.append(from_str[i])
82
uniform_range(from_sub, sub_nines)
84
if ord(from_str[i]) < ord(to_str[i]) - 1:
86
if to_sub == sub_nines:
87
digit_range(chr(ord(from_str[i]) + 1), to_str[i])
90
digit_range(chr(ord(from_str[i]) + 1), chr(ord(to_str[i]) - 1))
92
more_digits(sub_len, sub_len)
95
digit_range(to_str[i], to_str[i])
97
uniform_range(sub_zeros, to_sub)
101
out.append(from_str[i])
103
out.append(to_str[i])
106
if has_min and has_max:
107
if min_value < 0 and max_value < 0:
108
out.append("\"-\" (")
109
_generate_min_max_int(-max_value, -min_value, out, decimals_left, top_level=True)
114
out.append("\"-\" (")
115
_generate_min_max_int(0, -min_value, out, decimals_left, top_level=True)
119
min_s = str(min_value)
120
max_s = str(max_value)
121
min_digits = len(min_s)
122
max_digits = len(max_s)
124
for digits in range(min_digits, max_digits):
125
uniform_range(min_s, "9" * digits)
126
min_s = "1" + "0" * digits
128
uniform_range(min_s, max_s)
131
less_decimals = max(decimals_left - 1, 1)
135
out.append("\"-\" (")
136
_generate_min_max_int(None, -min_value, out, decimals_left, top_level=False)
137
out.append(") | [0] | [1-9] ")
138
more_digits(0, decimals_left - 1)
141
out.append("[0] | [1-9] ")
142
more_digits(0, less_decimals)
144
more_digits(1, decimals_left)
147
range_start = '1' if top_level else '0'
149
digit_range(range_start, chr(ord(c) - 1))
151
more_digits(1, less_decimals)
155
more_digits(0, less_decimals)
157
min_s = str(min_value)
162
digit_range("1" if top_level else "0", chr(ord(c) - 1))
164
more_digits(length, less_decimals)
168
_generate_min_max_int(int(min_s[1:]), None, out, less_decimals, top_level=False)
172
digit_range(chr(ord(c) + 1), "9")
174
more_digits(length - 1, less_decimals)
180
out.append("\"-\" [1-9] ")
181
more_digits(0, less_decimals)
183
_generate_min_max_int(0, max_value, out, decimals_left, top_level=True)
185
out.append("\"-\" (")
186
_generate_min_max_int(-max_value, None, out, decimals_left, top_level=False)
190
raise RuntimeError("At least one of min_value or max_value must be set")
193
def __init__(self, content: str, deps: list | None = None):
    """Store a rule body together with the names of the rules it refers to.

    Args:
        content: Text of the rule body.
        deps: Names of the rules the body depends on; any falsy value
            (``None`` or an empty list) is stored as a fresh empty list.
    """
    self.deps = deps if deps else []
    self.content = content
197
# Constraining spaces to prevent model "running away".
198
SPACE_RULE = '| " " | "\\n" [ \\t]{0,20}'
201
'boolean' : BuiltinRule('("true" | "false") space', []),
202
'decimal-part' : BuiltinRule('[0-9]{1,16}', []),
203
'integral-part': BuiltinRule('[0] | [1-9] [0-9]{0,15}', []),
204
'number' : BuiltinRule('("-"? integral-part) ("." decimal-part)? ([eE] [-+]? integral-part)? space', ['integral-part', 'decimal-part']),
205
'integer' : BuiltinRule('("-"? integral-part) space', ['integral-part']),
206
'value' : BuiltinRule('object | array | string | number | boolean | null', ['object', 'array', 'string', 'number', 'boolean', 'null']),
207
'object' : BuiltinRule('"{" space ( string ":" space value ("," space string ":" space value)* )? "}" space', ['string', 'value']),
208
'array' : BuiltinRule('"[" space ( value ("," space value)* )? "]" space', ['value']),
209
'uuid' : BuiltinRule(r'"\"" [0-9a-fA-F]{8} "-" [0-9a-fA-F]{4} "-" [0-9a-fA-F]{4} "-" [0-9a-fA-F]{4} "-" [0-9a-fA-F]{12} "\"" space', []),
210
'char' : BuiltinRule(r'[^"\\\x7F\x00-\x1F] | [\\] (["\\bfnrt] | "u" [0-9a-fA-F]{4})', []),
211
'string' : BuiltinRule(r'"\"" char* "\"" space', ['char']),
212
'null' : BuiltinRule('"null" space', []),
215
# TODO: support "uri", "email" string formats
216
STRING_FORMAT_RULES = {
217
'date' : BuiltinRule('[0-9]{4} "-" ( "0" [1-9] | "1" [0-2] ) "-" ( \"0\" [1-9] | [1-2] [0-9] | "3" [0-1] )', []),
218
'time' : BuiltinRule('([01] [0-9] | "2" [0-3]) ":" [0-5] [0-9] ":" [0-5] [0-9] ( "." [0-9]{3} )? ( "Z" | ( "+" | "-" ) ( [01] [0-9] | "2" [0-3] ) ":" [0-5] [0-9] )', []),
219
'date-time' : BuiltinRule('date "T" time', ['date', 'time']),
220
'date-string' : BuiltinRule('"\\"" date "\\"" space', ['date']),
221
'time-string' : BuiltinRule('"\\"" time "\\"" space', ['time']),
222
'date-time-string': BuiltinRule('"\\"" date-time "\\"" space', ['date-time']),
225
DOTALL = '[\\U00000000-\\U0010FFFF]'
228
RESERVED_NAMES = set(["root", "dot", *PRIMITIVE_RULES.keys(), *STRING_FORMAT_RULES.keys()])
230
# Runs of characters that are replaced by '-' when a rule name is derived.
INVALID_RULE_CHARS_RE = re.compile(r'[^a-zA-Z0-9-]+')
# Characters that must be backslash-escaped inside a quoted grammar literal.
# The backslash itself is included: JSON text such as "a\"b" or "a\\b" would
# otherwise produce a literal that terminates early or matches the wrong text.
GRAMMAR_LITERAL_ESCAPE_RE = re.compile(r'[\r\n"\\]')
# Same idea for text placed inside a [...] character range.
GRAMMAR_RANGE_LITERAL_ESCAPE_RE = re.compile(r'[\r\n"\]\-\\]')
# Replacement for each escapable character (shared by both patterns above).
GRAMMAR_LITERAL_ESCAPES = {'\r': '\\r', '\n': '\\n', '"': '\\"', '-': '\\-', ']': '\\]', '\\': '\\\\'}
235
NON_LITERAL_SET = set('|.()[]{}*+?')
236
ESCAPED_IN_REGEXPS_BUT_NOT_IN_LITERALS = set('^$.[]()|{}*+?')
239
class SchemaConverter:
240
def __init__(self, *, prop_order, allow_fetch, dotall, raw_pattern):
241
self._prop_order = prop_order
242
self._allow_fetch = allow_fetch
243
self._dotall = dotall
244
self._raw_pattern = raw_pattern
249
self._refs_being_resolved = set()
251
def _format_literal(self, literal):
    """Return ``literal`` wrapped in double quotes as a grammar literal.

    Each character matched by GRAMMAR_LITERAL_ESCAPE_RE is replaced by its
    entry in GRAMMAR_LITERAL_ESCAPES; a matched character without an entry
    is kept as-is.
    """
    def escape(match):
        char = match.group(0)
        return GRAMMAR_LITERAL_ESCAPES.get(char) or char

    return '"' + GRAMMAR_LITERAL_ESCAPE_RE.sub(escape, literal) + '"'
257
def not_literal(self, literal: str, dotall: bool = True, maybe_escaped_underscores = False) -> str:
259
not_literal('a') -> '[^a]'
260
not_literal('abc') -> '([^a] | "a" ([^b] | "b" ([^c])?)?)?'
262
assert len(literal) > 0, 'Empty literal not supported'
265
if maybe_escaped_underscores and c == '_':
268
yield f'"\\\\"? "{c}"'
271
if i < len(literal) - 1:
273
yield self._format_literal(c)
275
yield from recurse(i + 1)
278
return ''.join(('(', *recurse(0), ')'))
280
def _not_strings(self, strings):
284
self.is_end_of_string = False
286
def insert(self, string):
289
node = node.children.setdefault(c, TrieNode())
290
node.is_end_of_string = True
296
char_rule = self._add_primitive('char', PRIMITIVE_RULES['char'])
302
for c in sorted(node.children.keys()):
303
child = node.children[c]
314
elif child.is_end_of_string:
315
out.append(f' {char_rule}+')
319
out.append(f'[^"{"".join(rejects)}] {char_rule}*')
322
out.append(f' ){"" if trie.is_end_of_string else "?"} ["] space')
325
def _add_rule(self, name, rule):
326
esc_name = INVALID_RULE_CHARS_RE.sub('-', name)
327
if esc_name not in self._rules or self._rules[esc_name] == rule:
331
while f'{esc_name}{i}' in self._rules and self._rules[f'{esc_name}{i}'] != rule:
333
key = f'{esc_name}{i}'
334
self._rules[key] = rule
337
def resolve_refs(self, schema: dict, url: str):
339
Resolves all $ref fields in the given schema, fetching any remote schemas,
340
replacing $ref with absolute reference URL and populating self._refs with the
341
respective referenced (sub)schema dictionaries.
344
if isinstance(n, list):
345
return [visit(x) for x in n]
346
elif isinstance(n, dict):
348
if ref is not None and ref not in self._refs:
349
if ref.startswith('https://'):
350
assert self._allow_fetch, 'Fetching remote schemas is not allowed (use --allow-fetch for force)'
353
frag_split = ref.split('#')
354
base_url = frag_split[0]
356
target = self._refs.get(base_url)
358
target = self.resolve_refs(requests.get(ref).json(), base_url)
359
self._refs[base_url] = target
361
if len(frag_split) == 1 or frag_split[-1] == '':
363
elif ref.startswith('#/'):
368
raise ValueError(f'Unsupported ref {ref}')
370
for sel in ref.split('#')[-1].split('/')[1:]:
371
assert target is not None and sel in target, f'Error resolving ref {ref}: {sel} not in {target}'
374
self._refs[ref] = target
382
def _generate_union_rule(self, name, alt_schemas):
384
self.visit(alt_schema, f'{name}{"-" if name else "alternative-"}{i}')
385
for i, alt_schema in enumerate(alt_schemas)
388
def _visit_pattern(self, pattern, name):
390
Transforms a regular expression pattern into a GBNF rule.
392
Input: https://json-schema.org/understanding-json-schema/reference/regular_expressions
393
Output: https://github.com/ggerganov/llama.cpp/blob/master/grammars/README.md
395
Unsupported features: negative/positive lookaheads, greedy/non-greedy modifiers.
397
Mostly a 1:1 translation, except for {x} / {x,} / {x,y} quantifiers for which
398
we define sub-rules to keep the output lean.
401
assert pattern.startswith('^') and pattern.endswith('$'), 'Pattern must start with "^" and end with "$"'
402
pattern = pattern[1:-1]
406
length = len(pattern)
408
def to_rule(s: tuple[str, bool]) -> str:
409
(txt, is_literal) = s
410
return "\"" + txt + "\"" if is_literal else txt
412
def transform() -> tuple[str, bool]:
414
Parse a unit at index i (advancing it), and return its string representation + whether it's a literal.
418
nonlocal sub_rule_ids
421
# For each component of this sequence, store its string representation and whether it's a literal.
422
# We only need a flat structure here to apply repetition operators to the last item, and
423
# to merge literals at the end (we're parsing grouped ( sequences ) recursively and don't treat '|' specially
424
# (GBNF's syntax is luckily very close to regular expressions!)
425
seq: list[tuple[str, bool]] = []
431
# Accept any character... except \n and \r line break chars (\x0A and \x0D)
433
return self._add_rule(f'dot', rule)
438
for is_literal, g in itertools.groupby(seq, lambda x: x[1]):
440
ret.append((''.join(x[0] for x in g), True))
445
return (' '.join(to_rule(x) for x in seq), False)
450
seq.append((get_dot(), False))
455
assert pattern[i] != '?', f'Unsupported pattern syntax "{pattern[i]}" at index {i} of /{pattern}/'
456
seq.append((f'({to_rule(transform())})', False))
459
assert start > 0 and pattern[start-1] == '(', f'Unbalanced parentheses; start = {start}, i = {i}, pattern = {pattern}'
464
while i < length and pattern[i] != ']':
465
if pattern[i] == '\\':
466
square_brackets += pattern[i:i+2]
469
square_brackets += pattern[i]
471
assert i < length, f'Unbalanced square brackets; start = {start}, i = {i}, pattern = {pattern}'
472
square_brackets += ']'
474
seq.append((square_brackets, False))
476
seq.append(('|', False))
478
elif c in ('*', '+', '?'):
479
seq[-1] = (to_rule(seq[-1]) + c, False)
484
while i < length and pattern[i] != '}':
485
curly_brackets += pattern[i]
487
assert i < length, f'Unbalanced curly brackets; start = {start}, i = {i}, pattern = {pattern}'
488
curly_brackets += '}'
490
nums = [s.strip() for s in curly_brackets[1:-1].split(',')]
495
min_times = int(nums[0])
496
max_times = min_times
498
assert len(nums) == 2
499
min_times = int(nums[0]) if nums[0] else 0
500
max_times = int(nums[1]) if nums[1] else None
502
raise ValueError(f'Invalid quantifier {curly_brackets} in /{pattern}/')
504
(sub, sub_is_literal) = seq[-1]
506
if not sub_is_literal:
507
id = sub_rule_ids.get(sub)
509
id = self._add_rule(f'{name}-{len(sub_rule_ids) + 1}', sub)
510
sub_rule_ids[sub] = id
513
seq[-1] = (_build_repetition(f'"{sub}"' if sub_is_literal else sub, min_times, max_times), False)
517
if pattern[i] == '\\' and i < length - 1:
518
next = pattern[i + 1]
519
if next in ESCAPED_IN_REGEXPS_BUT_NOT_IN_LITERALS:
521
literal += pattern[i]
524
literal += pattern[i:i+2]
526
elif pattern[i] == '"' and not self._raw_pattern:
529
elif pattern[i] not in NON_LITERAL_SET and \
530
(i == length - 1 or literal == '' or pattern[i+1] == '.' or pattern[i+1] not in NON_LITERAL_SET):
531
literal += pattern[i]
536
seq.append((literal, True))
540
return self._add_rule(
542
to_rule(transform()) if self._raw_pattern \
543
else "\"\\\"\" " + to_rule(transform()) + " \"\\\"\" space")
546
def _resolve_ref(self, ref):
547
ref_name = ref.split('/')[-1]
548
if ref_name not in self._rules and ref not in self._refs_being_resolved:
549
self._refs_being_resolved.add(ref)
550
resolved = self._refs[ref]
551
ref_name = self.visit(resolved, ref_name)
552
self._refs_being_resolved.remove(ref)
555
def _generate_constant_rule(self, value):
    """Serialize ``value`` to JSON text and return it as a quoted grammar literal."""
    json_text = json.dumps(value)
    return self._format_literal(json_text)
558
def visit(self, schema, name):
559
schema_type = schema.get('type')
560
schema_format = schema.get('format')
561
rule_name = name + '-' if name in RESERVED_NAMES else name or 'root'
563
if (ref := schema.get('$ref')) is not None:
564
return self._add_rule(rule_name, self._resolve_ref(ref))
566
elif 'oneOf' in schema or 'anyOf' in schema:
567
return self._add_rule(rule_name, self._generate_union_rule(name, schema.get('oneOf') or schema['anyOf']))
569
elif isinstance(schema_type, list):
570
return self._add_rule(rule_name, self._generate_union_rule(name, [{**schema, 'type': t} for t in schema_type]))
572
elif 'const' in schema:
573
return self._add_rule(rule_name, self._generate_constant_rule(schema['const']) + ' space')
575
elif 'enum' in schema:
576
rule = '(' + ' | '.join((self._generate_constant_rule(v) for v in schema['enum'])) + ') space'
577
return self._add_rule(rule_name, rule)
579
elif schema_type in (None, 'object') and \
580
('properties' in schema or \
581
('additionalProperties' in schema and schema['additionalProperties'] is not True)):
582
required = set(schema.get('required', []))
583
properties = list(schema.get('properties', {}).items())
584
return self._add_rule(rule_name, self._build_object_rule(properties, required, name, schema.get('additionalProperties')))
586
elif schema_type in (None, 'object') and 'allOf' in schema:
590
def add_component(comp_schema, is_required):
591
if (ref := comp_schema.get('$ref')) is not None:
592
comp_schema = self._refs[ref]
594
if 'properties' in comp_schema:
595
for prop_name, prop_schema in comp_schema['properties'].items():
596
properties.append((prop_name, prop_schema))
598
required.add(prop_name)
600
for t in schema['allOf']:
602
for tt in t['anyOf']:
603
add_component(tt, is_required=False)
605
add_component(t, is_required=True)
607
return self._add_rule(rule_name, self._build_object_rule(properties, required, hybrid_name, additional_properties=None))
609
elif schema_type in (None, 'array') and ('items' in schema or 'prefixItems' in schema):
610
items = schema.get('items') or schema['prefixItems']
611
if isinstance(items, list):
612
return self._add_rule(
616
self.visit(item, f'{name}{"-" if name else ""}tuple-{i}')
617
for i, item in enumerate(items)) +
620
item_rule_name = self.visit(items, f'{name}{"-" if name else ""}item')
621
min_items = schema.get("minItems", 0)
622
max_items = schema.get("maxItems")
623
return self._add_rule(rule_name, '"[" space ' + _build_repetition(item_rule_name, min_items, max_items, separator_rule='"," space') + ' "]" space')
625
elif schema_type in (None, 'string') and 'pattern' in schema:
626
return self._visit_pattern(schema['pattern'], rule_name)
628
elif schema_type in (None, 'string') and re.match(r'^uuid[1-5]?$', schema_format or ''):
629
return self._add_primitive(
630
'root' if rule_name == 'root' else schema_format,
631
PRIMITIVE_RULES['uuid']
634
elif schema_type in (None, 'string') and f'{schema_format}-string' in STRING_FORMAT_RULES:
635
prim_name = f'{schema_format}-string'
636
return self._add_rule(rule_name, self._add_primitive(prim_name, STRING_FORMAT_RULES[prim_name]))
638
elif schema_type == 'string' and ('minLength' in schema or 'maxLength' in schema):
639
char_rule = self._add_primitive('char', PRIMITIVE_RULES['char'])
640
min_len = schema.get('minLength', 0)
641
max_len = schema.get('maxLength')
643
return self._add_rule(rule_name, r'"\"" ' + _build_repetition(char_rule, min_len, max_len) + r' "\"" space')
645
elif schema_type in (None, 'integer') and \
646
('minimum' in schema or 'exclusiveMinimum' in schema or 'maximum' in schema or 'exclusiveMaximum' in schema):
649
if 'minimum' in schema:
650
min_value = schema['minimum']
651
elif 'exclusiveMinimum' in schema:
652
min_value = schema['exclusiveMinimum'] + 1
653
if 'maximum' in schema:
654
max_value = schema['maximum']
655
elif 'exclusiveMaximum' in schema:
656
max_value = schema['exclusiveMaximum'] - 1
659
_generate_min_max_int(min_value, max_value, out)
660
out.append(") space")
661
return self._add_rule(rule_name, ''.join(out))
663
elif (schema_type == 'object') or (len(schema) == 0):
664
return self._add_rule(rule_name, self._add_primitive('object', PRIMITIVE_RULES['object']))
667
assert schema_type in PRIMITIVE_RULES, f'Unrecognized schema: {schema}'
668
# TODO: support minimum, maximum, exclusiveMinimum, exclusiveMaximum at least for zero
669
return self._add_primitive('root' if rule_name == 'root' else schema_type, PRIMITIVE_RULES[schema_type])
671
def _add_primitive(self, name: str, rule: BuiltinRule):
672
n = self._add_rule(name, rule.content)
674
for dep in rule.deps:
675
dep_rule = PRIMITIVE_RULES.get(dep) or STRING_FORMAT_RULES.get(dep)
676
assert dep_rule, f'Rule {dep} not known'
677
if dep not in self._rules:
678
self._add_primitive(dep, dep_rule)
681
def _build_object_rule(self, properties: List[Tuple[str, Any]], required: Set[str], name: str, additional_properties: Optional[Union[bool, Any]]):
682
prop_order = self._prop_order
683
# sort by position in prop_order (if specified) then by original order
684
sorted_props = [kv[0] for _, kv in sorted(enumerate(properties), key=lambda ikv: (prop_order.get(ikv[1][0], len(prop_order)), ikv[0]))]
686
prop_kv_rule_names = {}
687
for prop_name, prop_schema in properties:
688
prop_rule_name = self.visit(prop_schema, f'{name}{"-" if name else ""}{prop_name}')
689
prop_kv_rule_names[prop_name] = self._add_rule(
690
f'{name}{"-" if name else ""}{prop_name}-kv',
691
fr'{self._format_literal(json.dumps(prop_name))} space ":" space {prop_rule_name}'
693
required_props = [k for k in sorted_props if k in required]
694
optional_props = [k for k in sorted_props if k not in required]
696
if additional_properties is not None and additional_properties != False:
697
sub_name = f'{name}{"-" if name else ""}additional'
698
value_rule = self.visit(additional_properties, f'{sub_name}-value') if isinstance(additional_properties, dict) else \
699
self._add_primitive('value', PRIMITIVE_RULES['value'])
700
key_rule = self._add_primitive('string', PRIMITIVE_RULES['string']) if not sorted_props \
701
else self._add_rule(f'{sub_name}-k', self._not_strings(sorted_props))
703
prop_kv_rule_names["*"] = self._add_rule(
705
f'{key_rule} ":" space {value_rule}'
707
optional_props.append("*")
710
rule += ' "," space '.join(prop_kv_rule_names[k] for k in required_props)
715
rule += ' "," space ( '
717
def get_recursive_refs(ks, first_is_optional):
719
kv_rule_name = prop_kv_rule_names[k]
720
comma_ref = f'( "," space {kv_rule_name} )'
721
if first_is_optional:
722
res = comma_ref + ('*' if k == '*' else '?')
724
res = kv_rule_name + (' ' + comma_ref + "*" if k == '*' else '')
726
res += ' ' + self._add_rule(
727
f'{name}{"-" if name else ""}{k}-rest',
728
get_recursive_refs(rest, first_is_optional=True)
733
get_recursive_refs(optional_props[i:], first_is_optional=False)
734
for i in range(len(optional_props))
744
def format_grammar(self):
747
for name, rule in sorted(self._rules.items(), key=lambda kv: kv[0])
751
def main(args_in = None):
752
parser = argparse.ArgumentParser(
754
Generates a grammar (suitable for use in ./llama-cli) that produces JSON conforming to a
755
given JSON schema. Only a subset of JSON schema features are supported; more may be
762
type=lambda s: s.split(','),
764
comma-separated property names defining the order of precedence for object properties;
765
properties not specified here are given lower precedence than those that are, and
766
are kept in their original order from the schema. Required properties are always
767
given precedence over optional properties.
774
help='Whether to allow fetching referenced schemas over HTTPS')
779
help='Whether to treat dot (".") as matching all chars including line breaks in regular expression patterns')
784
help='Treats string patterns as raw patterns w/o quotes (or quote escapes)')
786
parser.add_argument('schema', help='file containing JSON schema ("-" for stdin)')
787
args = parser.parse_args(args_in)
789
if args.schema.startswith('https://'):
792
schema = requests.get(url).json()
793
elif args.schema == '-':
795
schema = json.load(sys.stdin)
797
url = f'file://{args.schema}'
798
with open(args.schema) as f:
799
schema = json.load(f)
800
converter = SchemaConverter(
801
prop_order={name: idx for idx, name in enumerate(args.prop_order)},
802
allow_fetch=args.allow_fetch,
804
raw_pattern=args.raw_pattern)
805
schema = converter.resolve_refs(schema, url)
806
converter.visit(schema, '')
807
print(converter.format_grammar())
810
if __name__ == '__main__':