google-research

Форк
0
238 строк · 7.7 Кб
"""Setup TensorFlow as external dependency"""

# Environment variables consumed by the `tf_configure` repository rule below.
# Every one of them is also listed in the rule's `environ` attribute.
_TF_HEADER_DIR = "TF_HEADER_DIR"  # directory whose files are copied into "include"
_TF_SHARED_LIBRARY_DIR = "TF_SHARED_LIBRARY_DIR"  # directory holding the shared library
_TF_SHARED_LIBRARY_NAME = "TF_SHARED_LIBRARY_NAME"  # file name of the shared library
_TF_CXX11_ABI_FLAG = "TF_CXX11_ABI_FLAG"  # value for -D_GLIBCXX_USE_CXX11_ABI=<flag>

def _tpl(repository_ctx, tpl, substitutions = {}, out = None):
    """Expands `//build_deps/tf_dependency:<tpl>.tpl` into the repository.

    Args:
        repository_ctx: the repository_ctx object.
        tpl: base name of the template file (without the ".tpl" suffix).
        substitutions: mapping of placeholder strings to replacement text.
        out: output file name; when empty or None, `tpl` is used.
    """
    template_label = Label("//build_deps/tf_dependency:%s.tpl" % tpl)
    repository_ctx.template(out or tpl, template_label, substitutions)

def _fail(msg):
    """Aborts the repository rule with a highlighted configuration error.

    Args:
        msg: human-readable description of what went wrong.
    """
    red = "\033[0;31m"
    no_color = "\033[0m"

    # This rule configures TensorFlow, so label the error accordingly (the
    # previous "Python Configuration Error" banner was a copy-paste leftover).
    fail("%sTensorFlow Configuration Error:%s %s\n" % (red, no_color, msg))

def _is_windows(repository_ctx):
    """Reports whether the host operating system is Windows.

    Args:
        repository_ctx: the repository_ctx object.

    Returns:
        True if the lower-cased host OS name contains "windows", else False.
    """
    return "windows" in repository_ctx.os.name.lower()

def _execute(
        repository_ctx,
        cmdline,
        error_msg = None,
        error_details = None,
        empty_stdout_fine = False):
    """Runs a command through repository_ctx and aborts if it looks failed.

    A run is treated as failed when it wrote anything to stderr, or when its
    stdout is empty and `empty_stdout_fine` is False. The exit code itself is
    not inspected.

    Args:
      repository_ctx: the repository_ctx object.
      cmdline: list of strings, the command to execute.
      error_msg: string, a summary of the error if the command fails.
      error_details: string, details about the error or steps to fix it.
      empty_stdout_fine: bool, if True an empty stdout is acceptable,
        otherwise it is an error.

    Returns:
      The result of repository_ctx.execute(cmdline).
    """
    result = repository_ctx.execute(cmdline)
    has_stderr = bool(result.stderr)
    missing_stdout = not empty_stdout_fine and not result.stdout
    if has_stderr or missing_stdout:
        summary = error_msg.strip() if error_msg else "Repository command failed"
        details = error_details if error_details else ""
        _fail("\n".join([summary, result.stderr.strip(), details]))
    return result

def _read_dir(repository_ctx, src_dir):
    """Lists every file below a directory, one full path per line.

    Subdirectories are traversed recursively. On non-Windows hosts the
    listing comes from `find -follow -type f`, so symlinks are followed; on
    Windows it comes from `dir /b /s /a-d`.

    Args:
        repository_ctx: the repository_ctx object.
        src_dir: directory to find files from.

    Returns:
        A string of all file paths inside the given dir, separated by line
        breaks and always using forward slashes.
    """
    if not _is_windows(repository_ctx):
        listing = _execute(
            repository_ctx,
            ["find", src_dir, "-follow", "-type", "f"],
            empty_stdout_fine = True,
        )
        return listing.stdout

    windows_dir = src_dir.replace("/", "\\")
    listing = _execute(
        repository_ctx,
        ["cmd.exe", "/c", "dir", windows_dir, "/b", "/s", "/a-d"],
        empty_stdout_fine = True,
    )

    # The paths end up in genrule outs, which must use forward slashes.
    return listing.stdout.replace("\\", "/")

def _genrule(genrule_name, command, outs):
    """Renders the text of a `genrule(...)` target.

    Args:
        genrule_name: A unique name for the genrule target.
        command: The shell command placed inside the triple-quoted `cmd`.
        outs: Pre-formatted lines listing the files the rule produces.

    Returns:
        The genrule declaration as a string, terminated by a newline.
    """
    lines = [
        "genrule(",
        '    name = "' + genrule_name + '",',
        "    outs = [",
        outs,
        "    ],",
        '    cmd = """',
        command,
        '   """,',
        ")",
        "",  # yields the trailing newline after ")"
    ]
    return "\n".join(lines)

def _norm_path(path):
    """Returns `path` with '/' separators and without one trailing slash.

    Backslashes are converted to forward slashes first, then a single trailing
    '/' (if present) is dropped. An empty path is returned unchanged instead
    of failing on an out-of-range index.

    Args:
        path: a file-system path string.

    Returns:
        The normalized path.
    """
    path = path.replace("\\", "/")

    # endswith() is safe on "", unlike indexing path[-1].
    if path.endswith("/"):
        path = path[:-1]
    return path

def _symlink_genrule_for_dir(
        repository_ctx,
        src_dir,
        dest_dir,
        genrule_name,
        src_files = [],
        dest_files = [],
        tf_pip_dir_rename_pair = []):
    """Returns a genrule that copies a set of files into the repository.

    Despite the name, files are copied with `cp -f` on every platform so the
    resulting setup works inside the sandbox.

    If src_dir is not None, files are discovered by listing src_dir
    recursively and src_files/dest_files are ignored; otherwise the explicit
    src_files/dest_files lists are used.

    Args:
        repository_ctx: the repository_ctx object.
        src_dir: source directory, or None to use src_files/dest_files.
        dest_dir: directory to place the copies in; prepended to every
          destination file name.
        genrule_name: genrule name.
        src_files: list of source files, used when src_dir is None.
        dest_files: list of corresponding destination files; must have the
          same length as src_files.
        tf_pip_dir_rename_pair: either empty or an [old, new] pair; every
          occurrence of old in a destination path is replaced by new. For
          example, in the TF pip package the source code is under
          "tensorflow_core", and we might want to replace it with
          "tensorflow" to match the header includes.

    Returns:
        The text of a genrule target that performs the copies.
    """
    rename_len = len(tf_pip_dir_rename_pair)
    if rename_len != 0 and rename_len != 2:
        _fail("The size of argument tf_pip_dir_rename_pair should be either 0 or 2, but %d is given." % rename_len)

    if src_dir != None:
        src_dir = _norm_path(src_dir)
        dest_dir = _norm_path(dest_dir)
        src_files = sorted(_read_dir(repository_ctx, src_dir).splitlines())
        dest_files = []
        for src_file in src_files:
            # Strip src_dir only as a leading prefix; a plain replace() would
            # also mangle any later occurrence of the same text in the path.
            if src_file.startswith(src_dir):
                dest_file = src_file[len(src_dir):]
            else:
                dest_file = src_file
            if rename_len:
                dest_file = dest_file.replace(tf_pip_dir_rename_pair[0], tf_pip_dir_rename_pair[1])
            dest_files.append(dest_file)
    elif len(src_files) != len(dest_files):
        _fail("src_files and dest_files must have the same length, but %d and %d are given." % (len(src_files), len(dest_files)))

    # Empty destination names are skipped.
    pairs = [
        (src_file, dest_file)
        for src_file, dest_file in zip(src_files, dest_files)
        if dest_file != ""
    ]

    # Decide on the number of outputs actually emitted. With exactly one
    # output, "$@" names that file precisely (even when it sits in a
    # subdirectory); with several, $(@D) is the package's output root.
    single_output = len(pairs) == 1

    command = []
    outs = []
    for src_file, dest_file in pairs:
        out = dest_dir + dest_file
        dest = "$@" if single_output else "$(@D)/" + out

        # Copy the headers to create a sandboxable setup.
        command.append('cp -f "%s" "%s"' % (src_file, dest))
        outs.append('        "%s",' % out)

    return _genrule(
        genrule_name,
        " && ".join(command),
        "\n".join(outs),
    )

def _required_env(repository_ctx, name):
    """Returns environment variable `name`, failing clearly if it is unset."""
    value = repository_ctx.os.environ.get(name)
    if value == None:
        _fail("Environment variable %s is not set." % name)
    return value

def _tf_pip_impl(repository_ctx):
    """Implementation of the `tf_configure` repository rule.

    Reads TF_HEADER_DIR, TF_SHARED_LIBRARY_DIR, TF_SHARED_LIBRARY_NAME and
    TF_CXX11_ABI_FLAG from the environment. It generates genrules that copy
    the TensorFlow headers (into "include") and the shared library into this
    repository, then expands the BUILD and build_defs.bzl templates.

    Args:
        repository_ctx: the repository_ctx object.
    """
    tf_header_dir = _required_env(repository_ctx, _TF_HEADER_DIR)
    tf_header_rule = _symlink_genrule_for_dir(
        repository_ctx,
        tf_header_dir,
        "include",
        "tf_header_include",
        tf_pip_dir_rename_pair = ["tensorflow_core", "tensorflow"],
    )

    tf_shared_library_dir = _required_env(repository_ctx, _TF_SHARED_LIBRARY_DIR)
    tf_shared_library_name = _required_env(repository_ctx, _TF_SHARED_LIBRARY_NAME)
    tf_shared_library_path = "%s/%s" % (tf_shared_library_dir, tf_shared_library_name)
    tf_cxx11_abi = "-D_GLIBCXX_USE_CXX11_ABI=%s" % _required_env(repository_ctx, _TF_CXX11_ABI_FLAG)

    tf_shared_library_rule = _symlink_genrule_for_dir(
        repository_ctx,
        None,
        "",
        tf_shared_library_name,
        [tf_shared_library_path],
        [tf_shared_library_name],
    )

    _tpl(repository_ctx, "BUILD", {
        "%{TF_HEADER_GENRULE}": tf_header_rule,
        "%{TF_SHARED_LIBRARY_GENRULE}": tf_shared_library_rule,
        "%{TF_SHARED_LIBRARY_NAME}": tf_shared_library_name,
    })

    _tpl(
        repository_ctx,
        "build_defs.bzl",
        {
            "%{tf_cx11_abi}": tf_cxx11_abi,
        },
    )

    # Repository rule implementations return None (or a dict of attributes
    # for reproducibility); providers such as PyInfo do not apply here.

# Repository rule that copies a locally installed TensorFlow (headers and
# shared library) into an external repository. The variables listed in
# `environ` are the ones whose changes cause the rule to be re-fetched.
tf_configure = repository_rule(
    implementation = _tf_pip_impl,
    environ = [
        _TF_HEADER_DIR,
        _TF_SHARED_LIBRARY_DIR,
        _TF_SHARED_LIBRARY_NAME,
        _TF_CXX11_ABI_FLAG,
    ],
)

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.