pytorch

Форк
0
/
gen_vulkan_spv.py 
768 строк · 24.2 Кб
1
#!/usr/bin/env python3
2

3
from __future__ import annotations
4

5
import argparse
6
import array
7
import codecs
8
import copy
9
import glob
10
import io
11
import os
12
import re
13
import sys
14
from itertools import product
15

16
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
17
import subprocess
18
import textwrap
19
from dataclasses import dataclass
20
from typing import Any
21

22
import yaml
23
from yaml.constructor import ConstructorError
24
from yaml.nodes import MappingNode
25

26
try:
27
    from yaml import CLoader as Loader
28
except ImportError:
29
    from yaml import Loader  # type: ignore[assignment, misc]
30

31
# Names of the generated C++ artifacts, placed under --output-path.
CPP_H_NAME = "spv.h"
CPP_SRC_NAME = "spv.cpp"

# Baseline template variables available to every shader. main() layers
# TYPES_ENV, FUNCS_ENV and any --env overrides on top of these.
DEFAULT_ENV: dict[str, Any] = dict(
    PRECISION="highp",
    FLOAT_IMAGE_FORMAT="rgba16f",
    INT_IMAGE_FORMAT="rgba32i",
    UINT_IMAGE_FORMAT="rgba32ui",
)

# Type lookup tables exposed to templates. Each is keyed by dtype name; the
# image/sampler tables are first keyed by dimensionality (3, then 2).
TYPES_ENV: dict[str, Any] = {
    "IMAGE_FORMAT": dict(
        float="rgba32f",
        half="rgba16f",
        int="rgba32i",
        uint="rgba32ui",
        int8="rgba8i",
        uint8="rgba8ui",
    ),
    "IMAGE_T": {
        ndim: dict(
            float=f"image{ndim}D",
            half=f"image{ndim}D",
            int=f"iimage{ndim}D",
            uint=f"uimage{ndim}D",
        )
        for ndim in (3, 2)
    },
    "SAMPLER_T": {
        ndim: dict(
            float=f"sampler{ndim}D",
            half=f"sampler{ndim}D",
            int=f"isampler{ndim}D",
            uint=f"usampler{ndim}D",
        )
        for ndim in (3, 2)
    },
    # NOTE(review): "int8" maps to "vec4" while "uint8" maps to "uvec4";
    # confirm the asymmetry is intended.
    "VEC4_T": dict(
        float="vec4",
        half="vec4",
        int="ivec4",
        uint="uvec4",
        int8="vec4",
        uint8="uvec4",
    ),
    # NOTE(review): "uint8" maps to the token "uint8" whereas "int8" maps to
    # "int"; confirm shaders using it define such a type.
    "T": dict(
        float="float",
        half="float",
        int="int",
        uint="uint",
        int8="int",
        uint8="uint8",
    ),
}

# Callable helpers exposed to templates, keyed by dimensionality.
FUNCS_ENV: dict[str, Any] = {
    "GET_POS": {
        3: lambda pos: pos,
        2: lambda pos: f"{pos}.xy",
    }
}
102

103

104
def extract_filename(path: str, keep_ext: bool = True) -> Any:
    """Return the final component of ``path``.

    With ``keep_ext=False`` everything from the first ``.`` onward is
    dropped, so ``"dir/conv.glsl"`` and ``"dir/conv.glslh"`` both give
    ``"conv"``.
    """
    base_name = os.path.basename(path)
    if not keep_ext:
        base_name = base_name.partition(".")[0]
    return base_name
109

110

111
############################
112
#  SPIR-V Code Generation  #
113
############################
114

115

116
# https://gist.github.com/pypt/94d747fe5180851196eb
117
class UniqueKeyLoader(Loader):
    """YAML loader whose mappings reject duplicate (and unhashable) keys.

    ``construct_mapping`` is overridden so that a key repeated inside one
    mapping of a shader-template YAML file raises ``ConstructorError``
    instead of being accepted.
    """

    def construct_mapping(self, node, deep=False):  # type: ignore[no-untyped-def]
        """Build a ``dict`` from ``node``, refusing unhashable or repeated keys.

        Raises:
            ConstructorError: if ``node`` is not a mapping node, if a key is
                unhashable, or if a key occurs more than once.
        """
        if not isinstance(node, MappingNode):
            raise ConstructorError(
                None,
                None,
                f"expected a mapping node, but found {node.id}",
                node.start_mark,
            )
        mapping = {}
        for key_node, value_node in node.value:
            key = self.construct_object(key_node, deep=deep)  # type: ignore[no-untyped-call]
            try:
                hash(key)
            except TypeError as e:
                raise ConstructorError(
                    "while constructing a mapping",
                    node.start_mark,
                    f"found unacceptable key ({e})",
                    key_node.start_mark,
                ) from e
            # Check for duplicate keys, naming the offending key so the YAML
            # author can find it.
            if key in mapping:
                raise ConstructorError(
                    "while constructing a mapping",
                    node.start_mark,
                    f"found duplicate key {key!r}",
                    key_node.start_mark,
                )
            value = self.construct_object(value_node, deep=deep)  # type: ignore[no-untyped-call]
            mapping[key] = value
        return mapping
149

150

151
# https://github.com/google/XNNPACK/blob/master/tools/xngen.py
152
def extract_leading_whitespace(line: str) -> str:
    """Return the run of whitespace characters at the start of ``line``."""
    without_indent = line.lstrip()
    return line[: len(line) - len(without_indent)]
155

156

157
# https://github.com/google/XNNPACK/blob/master/tools/xngen.py
158
def escape(line: str) -> str:
    """Translate one template line into a Python expression that rebuilds it.

    Literal text becomes double-quoted string literals and each ``${expr}``
    placeholder becomes ``str(expr)``; the pieces are joined with ``+``.
    An empty line yields ``'""'`` so the result is always a valid expression.

    Raises:
        ValueError: if a ``${`` has no closing ``}`` later on the line.
    """

    def _quote(text: str) -> str:
        # Escape backslashes before quotes: a GLSL line ending in "\" (a
        # multi-line #define) or containing "\n" inside a string literal must
        # come back out of the generated Python literal unchanged.
        return '"' + text.replace("\\", "\\\\").replace('"', '\\"') + '"'

    output_parts = []
    while "${" in line:
        start_pos = line.index("${")
        end_pos = line.index("}", start_pos + 2)
        if start_pos != 0:
            output_parts.append(_quote(line[:start_pos]))
        output_parts.append("str(" + line[start_pos + 2 : end_pos] + ")")
        line = line[end_pos + 1 :]
    if line:
        output_parts.append(_quote(line))
    # Without a fallback an empty line would produce "print(, ...)", which is
    # a SyntaxError in the generated program.
    return " + ".join(output_parts) if output_parts else '""'
170

171

172
# https://github.com/google/XNNPACK/blob/master/tools/xngen.py
173
def preprocess(
    input_text: str, variables: dict[str, Any], input_path: str = "codegen"
) -> str:
    """Expand a ``$``-templated source text into plain text.

    The template is translated into a Python program, which is then executed:

    * a line whose first non-blank character is ``$`` (but not ``${``) is
      Python code with its ``$`` characters removed; a trailing ``:`` opens a
      block, and the extra indentation of the next line becomes that block's
      Python indentation;
    * every other line is printed, with ``${expr}`` placeholders replaced by
      ``str(expr)``;
    * lines containing ``LINT`` are dropped, and empty lines are re-emitted.

    Args:
        input_text: The template source.
        variables: Names visible to the template's Python code and ``${}``
            expressions.
        input_path: File name used when compiling the generated program (it
            appears in tracebacks) and in indentation error messages.

    Returns:
        The expanded text; every emitted line ends with a newline.

    Raises:
        AssertionError: if a line's indentation is inconsistent with the
            enclosing ``$`` block.
    """
    input_lines = input_text.splitlines()
    python_lines = []

    # Empty lines seen since the last emitted line; flushed lazily so they
    # take the indentation of the line that follows them.
    blank_lines = 0

    # Indentation of the previous non-blank template line.
    last_indent = ""

    # Python indentation of the current line. Initialized so that trailing
    # blank lines can be flushed even when no line set it (e.g. an input of
    # only blank lines or lint markers).
    python_indent = ""

    # Stack of (template_indent, python_indent) pairs, one per open block.
    indent_stack = [("", "")]

    # Indicates whether this is the first line inside a Python
    # code block (i.e. for, while, if, elif, else).
    python_block_start = True
    for line_no, input_line in enumerate(input_lines, start=1):
        if input_line == "":
            blank_lines += 1
            continue
        # Skip lint markers.
        if "LINT" in input_line:
            continue

        input_indent = extract_leading_whitespace(input_line)
        if python_block_start:
            if not input_indent.startswith(last_indent):
                raise AssertionError(
                    f"{input_path}:{line_no}: first line of a block must extend "
                    "the indentation of the line that opened it"
                )
            extra_python_indent = input_indent[len(last_indent) :]
            python_indent = indent_stack[-1][1] + extra_python_indent
            indent_stack.append((input_indent, python_indent))
        else:
            # Dedent: close every block whose indentation this line leaves.
            while not input_indent.startswith(indent_stack[-1][0]):
                del indent_stack[-1]
        python_block_start = False

        python_indent = indent_stack[-1][1]

        # Flush pending blank lines at this line's indentation.
        python_lines.extend([python_indent + "print(file=OUT_STREAM)"] * blank_lines)
        blank_lines = 0

        stripped_input_line = input_line.strip()
        if stripped_input_line.startswith("$") and not stripped_input_line.startswith(
            "${"
        ):
            # Python statement; a trailing ":" opens a new block.
            if stripped_input_line.endswith(":"):
                python_block_start = True
            python_lines.append(python_indent + stripped_input_line.replace("$", ""))
        else:
            if not input_line.startswith(python_indent):
                raise AssertionError(
                    f"{input_path}:{line_no}: line does not start with the "
                    f"enclosing block's indentation {python_indent!r}"
                )
            python_lines.append(
                python_indent
                + f"print({escape(input_line[len(python_indent) :])}, file=OUT_STREAM)"
            )
        last_indent = input_indent

    python_lines.extend([python_indent + "print(file=OUT_STREAM)"] * blank_lines)

    exec_globals = dict(variables)
    output_stream = io.StringIO()
    exec_globals["OUT_STREAM"] = output_stream

    python_bytecode = compile("\n".join(python_lines), input_path, "exec")
    exec(python_bytecode, exec_globals)

    return output_stream.getvalue()
243

244

245
class SPVGenerator:
    """Expand GLSL shader templates into concrete shaders and compile them.

    Every ``*.glsl*`` file found (recursively) under the source directories
    is a template keyed by its file name up to the first ``.``. ``*.yaml``
    files declare, per template, default parameters and a list of named
    variants. Each variant, optionally multiplied out over its
    ``generate_variant_forall`` values, becomes its own output shader.
    Templates without YAML parameters produce one shader named after the
    file.
    """

    def __init__(
        self,
        src_dir_paths: str | list[str],
        env: dict[Any, Any],
        glslc_path: str | None,
    ) -> None:
        if isinstance(src_dir_paths, str):
            self.src_dir_paths = [src_dir_paths]
        else:
            self.src_dir_paths = src_dir_paths

        self.env = env
        self.glslc_path = glslc_path

        # Template name -> path of its GLSL source.
        self.glsl_src_files: dict[str, str] = {}
        self.template_yaml_files: list[str] = []

        self.addSrcAndYamlFiles(self.src_dir_paths)
        # Template name -> list of resolved parameter dicts, one per variant.
        self.shader_template_params: dict[Any, Any] = {}
        for yaml_file in self.template_yaml_files:
            self.parseTemplateYaml(yaml_file)

        # Output shader name -> (GLSL source path, template variables).
        self.output_shader_map: dict[str, tuple[str, dict[str, str]]] = {}
        self.constructOutputMap()

    def addSrcAndYamlFiles(self, src_dir_paths: list[str]) -> None:
        """Collect GLSL templates and YAML parameter files under each path.

        Glob results are sorted so the generated output does not depend on
        directory listing order. When two GLSL files share a template name,
        the later one wins.
        """
        for src_path in src_dir_paths:
            # Collect glsl source files
            glsl_files = sorted(
                glob.glob(os.path.join(src_path, "**", "*.glsl*"), recursive=True)
            )
            for file in glsl_files:
                self.glsl_src_files[extract_filename(file, keep_ext=False)] = file
            # Collect template yaml files
            yaml_files = sorted(
                glob.glob(os.path.join(src_path, "**", "*.yaml"), recursive=True)
            )
            self.template_yaml_files.extend(yaml_files)

    def generateVariantCombinations(
        self,
        iterated_params: dict[str, Any],
        exclude_params: set[str] | None = None,
    ) -> list[Any]:
        """Return the cartesian product of the iterated parameter values.

        ``iterated_params`` maps a parameter name to a list of entries with a
        ``VALUE`` and an optional ``SUFFIX`` (defaulting to the value).
        Parameters named in ``exclude_params`` are skipped. Each combination
        is a tuple of ``(name, suffix, value)`` triples. The suffix is always
        a ``str``, so non-string values (e.g. integers) can be iterated
        without an explicit ``SUFFIX``.
        """
        if exclude_params is None:
            exclude_params = set()
        all_iterated_params = []
        for param_name, value_list in iterated_params.items():
            if param_name in exclude_params:
                continue
            all_iterated_params.append(
                [
                    (
                        param_name,
                        str(value.get("SUFFIX", value["VALUE"])),
                        value["VALUE"],
                    )
                    for value in value_list
                ]
            )

        return list(product(*all_iterated_params))

    def parseTemplateYaml(self, yaml_file: str) -> None:
        """Record the variant parameter dicts declared in ``yaml_file``.

        Raises:
            KeyError: if a template's parameters are declared more than once.
            AssertionError: if a variant sets a parameter that has no
                default in ``parameter_names_with_default_values``.
        """
        with open(yaml_file, encoding="utf-8") as f:
            contents = yaml.load(f, Loader=UniqueKeyLoader)
        # An empty YAML document loads as None; there is nothing to register.
        if not contents:
            return

        for template_name, params_dict in contents.items():
            if template_name in self.shader_template_params:
                raise KeyError(f"{template_name} params file is defined twice")

            default_params = params_dict["parameter_names_with_default_values"]
            params_names = set(default_params.keys()).union({"NAME"})

            self.shader_template_params[template_name] = []

            default_iterated_params = params_dict.get("generate_variant_forall", None)

            for variant in params_dict["shader_variants"]:
                variant_params_names = set(variant.keys())
                invalid_keys = (
                    variant_params_names - params_names - {"generate_variant_forall"}
                )
                if invalid_keys:
                    raise AssertionError(
                        f"{yaml_file}: variant {variant.get('NAME')!r} of "
                        f"{template_name} sets unknown parameters "
                        f"{sorted(map(str, invalid_keys))}"
                    )

                iterated_params = variant.get(
                    "generate_variant_forall", default_iterated_params
                )

                if iterated_params is None:
                    default_params_copy = copy.deepcopy(default_params)
                    for key in variant:
                        default_params_copy[key] = variant[key]
                    self.shader_template_params[template_name].append(
                        default_params_copy
                    )
                    continue

                # Parameters the variant sets explicitly are not iterated.
                variant_combinations = self.generateVariantCombinations(
                    iterated_params, variant_params_names
                )
                for combination in variant_combinations:
                    default_params_copy = copy.deepcopy(default_params)
                    for key in variant:
                        if key != "generate_variant_forall":
                            default_params_copy[key] = variant[key]

                    variant_name = variant["NAME"]
                    for param_name, suffix, value in combination:
                        default_params_copy[param_name] = value
                        if suffix:
                            variant_name = f"{variant_name}_{suffix}"

                    default_params_copy["NAME"] = variant_name

                    self.shader_template_params[template_name].append(
                        default_params_copy
                    )

    def create_shader_params(
        self, variant_params: dict[str, Any] | None = None
    ) -> dict[str, str]:
        """Merge ``variant_params`` over a copy of the environment and set FORMAT.

        FORMAT is the image format for the variant's ``DTYPE`` (default
        ``"float"``). ``int``/``uint`` and unrecognized dtypes take it from
        the environment; the sized dtypes use fixed formats.
        """
        if variant_params is None:
            variant_params = {}
        shader_params = copy.deepcopy(self.env)
        for key, value in variant_params.items():
            shader_params[key] = value

        shader_dtype = shader_params.get("DTYPE", "float")

        if shader_dtype == "int":
            shader_params["FORMAT"] = self.env["INT_IMAGE_FORMAT"]
        elif shader_dtype == "uint":
            shader_params["FORMAT"] = self.env["UINT_IMAGE_FORMAT"]
        elif shader_dtype == "int32":
            shader_params["FORMAT"] = "rgba32i"
        elif shader_dtype == "uint32":
            shader_params["FORMAT"] = "rgba32ui"
        elif shader_dtype == "int8":
            shader_params["FORMAT"] = "rgba8i"
        elif shader_dtype == "uint8":
            shader_params["FORMAT"] = "rgba8ui"
        elif shader_dtype == "float32":
            shader_params["FORMAT"] = "rgba32f"
        # Assume float by default
        else:
            shader_params["FORMAT"] = self.env["FLOAT_IMAGE_FORMAT"]

        return shader_params

    def constructOutputMap(self) -> None:
        """Populate ``output_shader_map`` with every shader to generate.

        Templated shaders contribute one entry per variant, keyed by the
        variant's NAME. Every other GLSL file contributes one entry under its
        own name with default parameters.

        Raises:
            KeyError: if YAML parameters name a template with no GLSL source.
        """
        for shader_name, params in self.shader_template_params.items():
            if not params:
                continue
            try:
                source_glsl = self.glsl_src_files[shader_name]
            except KeyError:
                raise KeyError(
                    f"{shader_name}: YAML template parameters found, but no "
                    "matching GLSL source file"
                ) from None

            for variant in params:
                self.output_shader_map[variant["NAME"]] = (
                    source_glsl,
                    self.create_shader_params(variant),
                )

        for shader_name, source_glsl in self.glsl_src_files.items():
            if shader_name not in self.shader_template_params:
                self.output_shader_map[shader_name] = (
                    source_glsl,
                    self.create_shader_params(),
                )

    def generateSPV(self, output_dir: str) -> dict[str, str]:
        """Write each shader's expanded GLSL and compile it if glslc is set.

        Returns:
            Map from each compiled ``.spv`` path to the ``.glsl`` it was
            compiled from. It is empty when ``glslc_path`` is None.

        Raises:
            subprocess.CalledProcessError: if glslc fails on a shader.
        """
        output_file_map = {}
        for shader_name, (source_glsl, shader_params) in self.output_shader_map.items():
            # newline="" disables newline translation on read and write.
            with open(source_glsl, encoding="utf-8", newline="") as input_file:
                input_text = input_file.read()
            # Naming the template file makes errors raised while expanding it
            # point at the real source instead of "codegen".
            output_text = preprocess(input_text, shader_params, source_glsl)

            glsl_out_path = os.path.join(output_dir, f"{shader_name}.glsl")
            with open(glsl_out_path, "w", encoding="utf-8", newline="") as output_file:
                output_file.write(output_text)

            # If no GLSL compiler is specified, only the generated GLSL shaders
            # are written out. This is mainly for testing purposes.
            if self.glslc_path is None:
                continue

            spv_out_path = os.path.join(output_dir, f"{shader_name}.spv")

            cmd = [
                self.glslc_path,
                "-fshader-stage=compute",
                glsl_out_path,
                "-o",
                spv_out_path,
                "--target-env=vulkan1.0",
                "-Werror",
            ]
            for src_dir_path in self.src_dir_paths:
                cmd += ["-I", src_dir_path]

            print("glslc cmd:", cmd)
            subprocess.check_call(cmd)

            output_file_map[spv_out_path] = glsl_out_path

        return output_file_map
453

454

455
##############################################
456
#  Shader Info and Shader Registry Handling  #
457
##############################################
458

459

460
@dataclass
class ShaderInfo:
    """Metadata scraped from a generated GLSL file for shader registration.

    Fields are declared in constructor order; ``getShaderInfo`` builds
    instances positionally.
    """

    # Three integers from a " * TILE_SIZE = (x, y, z)" line; empty if absent.
    tile_size: list[int]
    # Vulkan descriptor-type names, one per "layout(set" line, in file order.
    layouts: list[str]
    # Storage-type token from a " * WEIGHT_STORAGE = " line; "" if absent.
    weight_storage_type: str = ""
    # Storage-type token from a " * BIAS_STORAGE = " line; "" if absent.
    bias_storage_type: str = ""
    # (op name, registry keys) from a " * REGISTER_FOR = " line, if present.
    register_for: tuple[str, list[str]] | None = None
467

468

469
def getName(filePath: str) -> str:
    """Turn a path's final component into an identifier-like name.

    Any ``/`` or ``.`` in the base name becomes ``_``; for example
    ``"out/conv2d.spv"`` becomes ``"conv2d_spv"``.
    """
    separators_to_underscores = str.maketrans("/.", "__")
    return os.path.basename(filePath).translate(separators_to_underscores)
471

472

473
def isDescriptorLine(lineStr: str) -> bool:
    """Return True if the line begins with ``layout(set`` (a descriptor binding)."""
    return bool(re.match(r"layout\(set", lineStr))
476

477

478
def isTileSizeLine(lineStr: str) -> bool:
    """Return True if the line begins with `` * TILE_SIZE = (``."""
    return lineStr.startswith(" * TILE_SIZE = (")
481

482

483
def findTileSizes(lineStr: str) -> list[int]:
    """Parse the three integers from a `` * TILE_SIZE = (x, y, z)`` line.

    Raises:
        AssertionError: if the line does not start with that exact form.
    """
    found = re.search(
        r"^ \* TILE_SIZE = \(([0-9]+), ([0-9]+), ([0-9]+)\)", lineStr
    )
    if found is None:
        raise AssertionError("matches is None in findTileSizes")
    return [int(group) for group in found.groups()]
489

490

491
def isWeightStorageTypeLine(lineStr: str) -> bool:
    """Return True if the line begins with `` * WEIGHT_STORAGE = ``."""
    return lineStr.startswith(" * WEIGHT_STORAGE = ")
494

495

496
def getWeightStorageType(lineStr: str) -> str:
    """Return the storage type named on a `` * WEIGHT_STORAGE = ...`` line.

    Accepts tokens such as ``TEXTURE_2D`` and also plain words such as
    ``BUFFER``. The ``_<n>D`` part is optional so that every key of
    ``storageTypeToEnum`` can be parsed.

    Raises:
        AssertionError: if no storage-type token follows the ``=``.
    """
    weight_storage_id = r"^ \* WEIGHT_STORAGE = ([a-zA-Z]+(?:_\dD)?)"
    matches = re.search(weight_storage_id, lineStr)
    if matches is None:
        raise AssertionError("matches is None in getWeightStorageType")
    return matches.group(1)
502

503

504
def isBiasStorageTypeLine(lineStr: str) -> bool:
    """Return True if the line begins with `` * BIAS_STORAGE = ``."""
    return lineStr.startswith(" * BIAS_STORAGE = ")
507

508

509
def getBiasStorageType(lineStr: str) -> str:
    """Return the storage type named on a `` * BIAS_STORAGE = ...`` line.

    Accepts tokens such as ``TEXTURE_2D`` and also plain words such as
    ``BUFFER``. The ``_<n>D`` part is optional so that every key of
    ``storageTypeToEnum`` can be parsed.

    Raises:
        AssertionError: if no storage-type token follows the ``=``.
    """
    bias_storage_id = r"^ \* BIAS_STORAGE = ([a-zA-Z]+(?:_\dD)?)"
    matches = re.search(bias_storage_id, lineStr)
    if matches is None:
        raise AssertionError("matches is None in getBiasStorageType")
    return matches.group(1)
515

516

517
def isRegisterForLine(lineStr: str) -> bool:
    """Return True for a `` * REGISTER_FOR = ('op', ['key', ...])`` line.

    The line must name an op and list at least one registry key.
    """
    register_for_pattern = (
        r"^ \* REGISTER_FOR = "
        r"\('([A-Za-z0-9_]+)'\s*,\s*"
        r"\['([A-Za-z0-9_]+)'.*\]\)"
    )
    return bool(re.search(register_for_pattern, lineStr))
523

524

525
def findRegisterFor(lineStr: str) -> tuple[str, list[str]]:
    """Split a `` * REGISTER_FOR = ('op', ['key', ...])`` line into its parts.

    Every single-quoted identifier on the line is collected. The first is the
    op name and the rest are the registry keys.

    Raises:
        AssertionError: if the line contains no single-quoted identifier.
    """
    register_for_pattern = r"'([A-Za-z0-9_]+)'"
    matches = re.findall(register_for_pattern, lineStr)
    # re.findall returns a (possibly empty) list, never None.
    if not matches:
        raise AssertionError("no quoted identifiers found in findRegisterFor")
    return (matches[0], matches[1:])
532

533

534
# (regex, Vulkan descriptor type) pairs consulted in insertion order by
# determineDescriptorType; the first pattern found in a line wins. That is why
# a "uniform ... sampler3D" declaration resolves to a combined image sampler
# rather than a uniform buffer.
typeIdMapping = dict(
    [
        (r"image[123]D\b", "VK_DESCRIPTOR_TYPE_STORAGE_IMAGE"),
        (r"sampler[123]D\b", "VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER"),
        (r"\bbuffer\b", "VK_DESCRIPTOR_TYPE_STORAGE_BUFFER"),
        (r"\buniform\b", "VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER"),
    ]
)

# Storage-type token (as parsed from WEIGHT_STORAGE / BIAS_STORAGE lines) to
# the C++ enumerator emitted into the generated source; "" means "not given".
storageTypeToEnum = {
    token: f"api::StorageType::{token or 'UNKNOWN'}"
    for token in ("TEXTURE_2D", "TEXTURE_3D", "BUFFER", "")
}
547

548

549
def determineDescriptorType(lineStr: str) -> str:
    """Map a ``layout(set...`` declaration to its Vulkan descriptor type name.

    The patterns in ``typeIdMapping`` are tried in insertion order, and the
    first one found anywhere in the line decides the type.

    Raises:
        AssertionError: if no pattern matches the line.
    """
    matched_type = next(
        (
            descriptor_type
            for pattern, descriptor_type in typeIdMapping.items()
            if re.search(pattern, lineStr)
        ),
        None,
    )
    if matched_type is None:
        raise AssertionError(
            "No matching descriptor type for " + lineStr + " in determineDescriptorType"
        )
    return matched_type
556

557

558
def getShaderInfo(srcFilePath: str) -> ShaderInfo:
    """Scan a generated GLSL file for the metadata used to register it.

    Every line is checked against each recognized form:
    ``layout(set`` descriptor declarations (appended in order), and the
    `` * TILE_SIZE``, `` * WEIGHT_STORAGE``, `` * BIAS_STORAGE`` and
    `` * REGISTER_FOR`` lines. For these single-valued fields the last
    occurrence wins.
    """
    shader_info = ShaderInfo([], [], "")
    # The generated GLSL is written as UTF-8 by SPVGenerator.generateSPV, so
    # read it back as UTF-8 rather than with the locale's default encoding.
    with open(srcFilePath, encoding="utf-8") as srcFile:
        for line in srcFile:
            if isDescriptorLine(line):
                shader_info.layouts.append(determineDescriptorType(line))
            if isTileSizeLine(line):
                shader_info.tile_size = findTileSizes(line)
            if isWeightStorageTypeLine(line):
                shader_info.weight_storage_type = getWeightStorageType(line)
            if isBiasStorageTypeLine(line):
                shader_info.bias_storage_type = getBiasStorageType(line)
            if isRegisterForLine(line):
                shader_info.register_for = findRegisterFor(line)

    return shader_info
574

575

576
##########################
577
#  C++ File Generation  #
578
#########################
579

580
# Skeleton of the generated C++ source; genCppFiles fills the placeholders:
#   {spv_bin_arrays}        - one `const uint32_t <name>_bin[]` array per shader
#   {register_shader_infos} - one register_shader(...) call per shader
#   {shader_info_registry}  - register_op_dispatch(...) calls (REGISTER_FOR)
# Literal C++ braces are doubled because the text goes through str.format.
cpp_template = """
#include <ATen/native/vulkan/api/ShaderRegistry.h>
#include <stdint.h>
#include <vector>

using namespace at::native::vulkan;

namespace at {{
namespace native {{
namespace vulkan {{

namespace {{

{spv_bin_arrays}

}}

static void register_fn() {{

{register_shader_infos}

{shader_info_registry}

}}

static const api::ShaderRegisterInit register_shaders(&register_fn);

}}
}}
}}

"""
612

613

614
def generateSpvBinStr(spvPath: str, name: str) -> tuple[int, str]:
    """Render a compiled SPIR-V file as a C++ ``uint32_t`` array definition.

    Returns:
        ``(size_in_bytes, source)``, where ``source`` defines ``<name>_bin``
        with one word per line.
    """
    with open(spvPath, "rb") as spv_file:
        # NOTE(review): the size computation assumes the array typecode "I"
        # is 4 bytes wide, matching the emitted uint32_t type.
        words = array.array("I", spv_file.read())
    body = textwrap.indent(",\n".join(map(str, words)), "  ")
    return 4 * len(words), f"const uint32_t {name}_bin[] = {{\n{body}\n}};"
624

625

626
def generateShaderInfoStr(shader_info: ShaderInfo, name: str, sizeBytes: int) -> str:
    """Emit the C++ statement that registers one shader's ``api::ShaderInfo``.

    The constructor arguments are, in order:

    1. the quoted shader name;
    2. its ``<name>_bin`` array;
    3. the byte size;
    4. the brace-initialized descriptor layout list;
    5. the tile size, or an empty ``std::vector<uint32_t>``;
    6. the weight storage-type enumerator;
    7. the bias storage-type enumerator.

    The result is indented by four spaces.
    """
    if shader_info.tile_size:
        tile_size = "{" + ", ".join(map(str, shader_info.tile_size)) + "}"
    else:
        tile_size = "std::vector<uint32_t>()"

    layouts = "{" + ",\n ".join(shader_info.layouts) + "}"

    ctor_args = ",\n".join(
        (
            f'"{name}"',
            f"{name}_bin",
            str(sizeBytes),
            layouts,
            tile_size,
            storageTypeToEnum[shader_info.weight_storage_type],
            storageTypeToEnum[shader_info.bias_storage_type],
        )
    )
    statement = (
        "api::shader_registry().register_shader(\n"
        "  api::ShaderInfo(\n"
        + textwrap.indent(ctor_args, "     ")
        + "));\n"
    )
    return textwrap.indent(statement, "    ")
653

654

655
def generateShaderDispatchStr(shader_info: ShaderInfo, name: str) -> str:
    """Emit the C++ ``register_op_dispatch`` calls for a shader.

    One call is generated per registry key in ``shader_info.register_for``.
    Each key is upper-cased into an ``api::DispatchKey`` enumerator, and the
    calls are placed one per line, indented by four spaces. Returns ``""``
    when the shader has no ``REGISTER_FOR`` or lists no keys.
    """
    if shader_info.register_for is None:
        return ""

    (op_name, registry_keys) = shader_info.register_for
    # Every registry key needs its own registration call; collect them all
    # rather than keeping only the last one.
    dispatch_calls = [
        f'api::shader_registry().register_op_dispatch("{op_name}", api::DispatchKey::{registry_key.upper()}, "{name}");'
        for registry_key in registry_keys
    ]
    return textwrap.indent("\n".join(dispatch_calls), "    ")
667

668

669
def genCppFiles(
    spv_files: dict[str, str], cpp_header_path: str, cpp_src_file_path: str
) -> None:
    """Write the C++ source that embeds and registers every compiled shader.

    Args:
        spv_files: Map from each ``.spv`` path to the generated ``.glsl``
            file it was compiled from.
        cpp_header_path: Currently unused; no header file is written.
        cpp_src_file_path: Destination of the generated C++ source.
    """
    spv_bin_strs = []
    register_shader_info_strs = []
    shader_registry_strs = []

    for spvPath, srcPath in spv_files.items():
        # "dir/conv2d.spv" -> "conv2d_spv" -> "conv2d". Only the trailing
        # "_spv" is removed, so shader names that contain "_spv" elsewhere
        # are preserved.
        name = getName(spvPath)
        if name.endswith("_spv"):
            name = name[: -len("_spv")]

        sizeBytes, spv_bin_str = generateSpvBinStr(spvPath, name)
        spv_bin_strs.append(spv_bin_str)

        shader_info = getShaderInfo(srcPath)

        register_shader_info_strs.append(
            generateShaderInfoStr(shader_info, name, sizeBytes)
        )

        if shader_info.register_for is not None:
            shader_registry_strs.append(generateShaderDispatchStr(shader_info, name))

    cpp = cpp_template.format(
        spv_bin_arrays="\n".join(spv_bin_strs),
        register_shader_infos="\n".join(register_shader_info_strs),
        shader_info_registry="\n".join(shader_registry_strs),
    )

    with open(cpp_src_file_path, "w") as fw:
        fw.write(cpp)
703

704

705
##########
706
#  Main  #
707
##########
708

709

710
def parse_arg_env(items: list[str] | None) -> dict[Any, Any]:
    """Parse ``KEY=VALUE`` strings (from ``--env``) into a dict.

    Whitespace around keys and values is stripped. Only the first ``=``
    separates key from value, so values may themselves contain ``=``.
    ``None`` or an empty list yields ``{}``.

    Raises:
        ValueError: if an item contains no ``=``.
    """
    env: dict[Any, Any] = {}
    for item in items or ():
        key, sep, value = item.partition("=")
        if not sep:
            raise ValueError(f"--env expects KEY=VALUE, got {item!r}")
        env[key.strip()] = value.strip()
    return env
719

720

721
def main(argv: list[str]) -> int:
    """Command-line entry point: expand shaders, compile them and emit C++.

    Args:
        argv: Currently ignored; ``parse_args()`` reads ``sys.argv`` itself.

    Returns:
        0 on success; failures propagate as exceptions.
    """
    parser = argparse.ArgumentParser(description="")
    parser.add_argument(
        "-i",
        "--glsl-paths",
        nargs="+",
        help="List of paths to look for GLSL source files, separated by spaces. Ex: --glsl-paths path1 path2 path3",
        default=["."],
    )
    parser.add_argument(
        "-c", "--glslc-path", required=True, help="Path to the glslc compiler executable"
    )
    parser.add_argument(
        "-t",
        "--tmp-dir-path",
        required=True,
        help="Directory for the generated .glsl and compiled .spv files (e.g. /tmp)",
    )
    parser.add_argument(
        "-o",
        "--output-path",
        required=True,
        help="Directory in which the generated C++ source is written",
    )
    parser.add_argument(
        "--env", metavar="KEY=VALUE", nargs="*", help="Set a number of key-value pairs"
    )
    options = parser.parse_args()

    # Build a fresh environment instead of updating (and then writing --env
    # overrides into) the module-level DEFAULT_ENV through an alias.
    env: dict[str, Any] = {**DEFAULT_ENV, **TYPES_ENV, **FUNCS_ENV}
    env.update(parse_arg_env(options.env))

    os.makedirs(options.output_path, exist_ok=True)
    os.makedirs(options.tmp_dir_path, exist_ok=True)

    shader_generator = SPVGenerator(options.glsl_paths, env, options.glslc_path)
    output_spv_files = shader_generator.generateSPV(options.tmp_dir_path)

    genCppFiles(
        output_spv_files,
        f"{options.output_path}/{CPP_H_NAME}",
        f"{options.output_path}/{CPP_SRC_NAME}",
    )

    return 0
761

762

763
def invoke_main() -> None:
    """Run ``main`` with the process arguments and exit with its status."""
    exit_code = main(sys.argv)
    sys.exit(exit_code)


if __name__ == "__main__":
    invoke_main()  # pragma: no cover
769

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.