google-research
238 строк · 7.7 Кб
"""Bazel repository rule exposing a local TensorFlow installation.

The tf_configure rule reads the environment variables named below, generates
genrules that copy the TensorFlow headers and shared library into the
external repository, and expands BUILD / build_defs.bzl from templates.
"""

# Environment variable: directory holding the TensorFlow headers.
_TF_HEADER_DIR = "TF_HEADER_DIR"

# Environment variable: directory containing the TensorFlow shared library.
_TF_SHARED_LIBRARY_DIR = "TF_SHARED_LIBRARY_DIR"

# Environment variable: file name of the TensorFlow shared library.
_TF_SHARED_LIBRARY_NAME = "TF_SHARED_LIBRARY_NAME"

# Environment variable: value substituted into -D_GLIBCXX_USE_CXX11_ABI=<value>.
_TF_CXX11_ABI_FLAG = "TF_CXX11_ABI_FLAG"
7
def _tpl(repository_ctx, tpl, substitutions = {}, out = None):
    """Expands //build_deps/tf_dependency:<tpl>.tpl into the repository.

    Args:
      repository_ctx: the repository_ctx object.
      tpl: base name of the template; the template file is "<tpl>.tpl".
      substitutions: dict mapping placeholder strings to their replacements.
      out: output path inside the repository; defaults to `tpl` when falsy.
    """
    destination = out if out else tpl
    template_label = Label("//build_deps/tf_dependency:%s.tpl" % tpl)
    repository_ctx.template(destination, template_label, substitutions)
16
def _fail(msg):
    """Aborts the repository rule with a highlighted configuration error.

    The prefix used to read "Python Configuration Error" (copied from the
    Python auto-configuration rule), which mislabeled every TensorFlow
    configuration failure reported by this file.

    Args:
      msg: string, description of what went wrong.
    """
    red = "\033[0;31m"
    no_color = "\033[0m"
    fail("%sTensorFlow Configuration Error:%s %s\n" % (red, no_color, msg))
22
def _is_windows(repository_ctx):
    """Reports whether the host OS name contains "windows" (case-insensitive).

    Args:
      repository_ctx: the repository_ctx object.

    Returns:
      True on a Windows host, False otherwise.
    """
    return "windows" in repository_ctx.os.name.lower()
29
def _execute(
        repository_ctx,
        cmdline,
        error_msg = None,
        error_details = None,
        empty_stdout_fine = False):
    """Runs a command and fails the repository rule if it did not succeed.

    The command (an argv list, run directly by repository_ctx.execute) is
    considered failed when it exits with a non-zero status, writes anything
    to stderr, or, unless empty_stdout_fine is set, prints nothing to stdout.

    Args:
      repository_ctx: the repository_ctx object.
      cmdline: list of strings, the command and its arguments.
      error_msg: string, a summary of the error if the command fails.
      error_details: string, details about the error or steps to fix it.
      empty_stdout_fine: bool, if True, an empty stdout result is fine,
        otherwise it's an error.

    Returns:
      The result of repository_ctx.execute(cmdline).
    """
    result = repository_ctx.execute(cmdline)

    # The exit status is checked as well: a failing command does not always
    # write to stderr, and may still print partial output on stdout.
    if (result.return_code != 0 or
        result.stderr or
        not (empty_stdout_fine or result.stdout)):
        _fail("\n".join([
            error_msg.strip() if error_msg else "Repository command failed",
            "Exit code: %d" % result.return_code,
            result.stderr.strip(),
            error_details if error_details else "",
        ]))
    return result
59
def _read_dir(repository_ctx, src_dir):
    """Lists every file below a directory, one full path per line.

    Subdirectories are traversed and symlinks followed. On Windows the
    listing comes from `dir /b /s /a-d`, and its backslashes are turned
    into forward slashes; elsewhere it comes from `find -follow -type f`.

    Args:
      repository_ctx: the repository_ctx object.
      src_dir: directory to list.

    Returns:
      A string containing the full path of each file, separated by line
      breaks.
    """
    windows = _is_windows(repository_ctx)
    if windows:
        cmdline = ["cmd.exe", "/c", "dir", src_dir.replace("/", "\\"), "/b", "/s", "/a-d"]
    else:
        cmdline = ["find", src_dir, "-follow", "-type", "f"]

    listing = _execute(repository_ctx, cmdline, empty_stdout_fine = True).stdout
    if windows:
        # The paths end up in genrule outs, which require forward slashes.
        listing = listing.replace("\\", "/")
    return listing
93
def _genrule(genrule_name, command, outs):
    """Renders the text of a genrule target.

    Args:
      genrule_name: A unique name for the genrule target.
      command: The command placed verbatim inside the triple-quoted cmd.
      outs: Pre-formatted body of the outs list (one quoted entry per line).

    Returns:
      The genrule definition as a string, terminated by a newline.
    """
    rule_lines = [
        "genrule(",
        ' name = "' + genrule_name + '",',
        " outs = [",
        outs,
        " ],",
        ' cmd = """',
        command,
        ' """,',
        ")",
    ]
    return "\n".join(rule_lines) + "\n"
119
def _norm_path(path):
    """Returns `path` with '/' separators and no trailing slashes.

    All trailing slashes are removed, not just one: "dir//" used to become
    "dir/", which later stripped the separator out of destination paths.
    An empty path is returned unchanged instead of failing on `path[-1]`.

    Args:
      path: string, a file-system path using '/' or '\\' separators.

    Returns:
      The normalized path string.
    """
    return path.replace("\\", "/").rstrip("/")
126
def _symlink_genrule_for_dir(
        repository_ctx,
        src_dir,
        dest_dir,
        genrule_name,
        src_files = [],
        dest_files = [],
        tf_pip_dir_rename_pair = []):
    """Returns a genrule that copies a set of files into the build tree.

    Despite the name, the generated command always copies (`cp -f`) rather
    than symlinks, so the outputs are usable inside a sandbox.

    If src_dir is not None, the files are discovered by listing src_dir
    recursively; otherwise they are taken from src_files and dest_files.

    Args:
      repository_ctx: the repository_ctx object.
      src_dir: source directory to list, or None to use src_files/dest_files.
      dest_dir: string prepended to every destination path (when src_dir is
        given, the directory the files are copied into).
      genrule_name: genrule name.
      src_files: list of source file paths; used only when src_dir is None.
      dest_files: list of destination paths parallel to src_files; used only
        when src_dir is None.
      tf_pip_dir_rename_pair: either empty or [old, new]. When src_dir is
        given, every occurrence of old in the destination paths is replaced
        with new. For example, in the TF pip package the sources are under
        "tensorflow_core", and we may want "tensorflow" to match the header
        includes.

    Returns:
      The genrule target (as a string) that copies the files.
    """
    rename_len = len(tf_pip_dir_rename_pair)
    if rename_len != 0 and rename_len != 2:
        _fail("The size of argument tf_pip_dir_rename_pair should be either 0 or 2, but %d is given." % rename_len)

    if src_dir != None:
        src_dir = _norm_path(src_dir)
        dest_dir = _norm_path(dest_dir)
        src_files = sorted(_read_dir(repository_ctx, src_dir).splitlines())

        # Strip src_dir only as a leading prefix; a global replace would also
        # rewrite any later occurrence of the same substring inside a path.
        dest_files = [
            f[len(src_dir):] if f.startswith(src_dir) else f
            for f in src_files
        ]
        if rename_len:
            dest_files = [
                f.replace(tf_pip_dir_rename_pair[0], tf_pip_dir_rename_pair[1])
                for f in dest_files
            ]
    elif len(src_files) != len(dest_files):
        _fail("src_files and dest_files must have the same length, but %d and %d are given." % (len(src_files), len(dest_files)))

    # Entries with an empty destination are skipped. Drop them before
    # counting, so the single-output decision below matches the real outs.
    pairs = [(src, rel) for src, rel in zip(src_files, dest_files) if rel != ""]
    single_output = len(pairs) == 1

    command = []
    outs = []
    for src, rel in pairs:
        out = dest_dir + rel

        # With exactly one output, $(@D) is that file's own directory, so
        # appending a relative path to it is wrong for nested files; $@ names
        # the single output file directly.
        dest = "$@" if single_output else "$(@D)/" + out

        # Copy the headers to create a sandboxable setup.
        command.append('cp -f "%s" "%s"' % (src, dest))
        outs.append(' "' + out + '",')

    return _genrule(
        genrule_name,
        " && ".join(command),
        "\n".join(outs),
    )
190
def _get_required_env(repository_ctx, name):
    """Returns the value of environment variable `name`; fails if unset or empty."""
    value = repository_ctx.os.environ.get(name)
    if not value:
        _fail("Environment variable %s is not set (or is empty); it is required to configure the TensorFlow dependency." % name)
    return value

def _tf_pip_impl(repository_ctx):
    """Implementation of the tf_configure repository rule.

    Reads the TF_* environment variables and builds one genrule copying the
    TensorFlow headers into "include" (renaming "tensorflow_core" to
    "tensorflow") and one copying the shared library. It substitutes them
    into the BUILD template, and writes build_defs.bzl with the
    -D_GLIBCXX_USE_CXX11_ABI flag.

    A missing env var used to surface as an opaque dict-lookup error; each
    one is now read through _get_required_env, which fails with a clear
    message.

    Args:
      repository_ctx: the repository_ctx object.
    """
    tf_header_dir = _get_required_env(repository_ctx, _TF_HEADER_DIR)
    tf_header_rule = _symlink_genrule_for_dir(
        repository_ctx,
        tf_header_dir,
        "include",
        "tf_header_include",
        tf_pip_dir_rename_pair = ["tensorflow_core", "tensorflow"],
    )

    tf_shared_library_dir = _get_required_env(repository_ctx, _TF_SHARED_LIBRARY_DIR)
    tf_shared_library_name = _get_required_env(repository_ctx, _TF_SHARED_LIBRARY_NAME)
    tf_shared_library_path = "%s/%s" % (tf_shared_library_dir, tf_shared_library_name)
    tf_cx11_abi = "-D_GLIBCXX_USE_CXX11_ABI=%s" % _get_required_env(repository_ctx, _TF_CXX11_ABI_FLAG)

    # No src_dir: copy exactly one file, placed at the package root.
    tf_shared_library_rule = _symlink_genrule_for_dir(
        repository_ctx,
        None,
        "",
        tf_shared_library_name,
        [tf_shared_library_path],
        [tf_shared_library_name],
    )

    _tpl(repository_ctx, "BUILD", {
        "%{TF_HEADER_GENRULE}": tf_header_rule,
        "%{TF_SHARED_LIBRARY_GENRULE}": tf_shared_library_rule,
        "%{TF_SHARED_LIBRARY_NAME}": tf_shared_library_name,
    })

    _tpl(
        repository_ctx,
        "build_defs.bzl",
        {
            "%{tf_cx11_abi}": tf_cx11_abi,
        },
    )

    # NOTE(review): repository rule implementations are documented to return
    # None or a dict of attributes; returning a provider list looks like a
    # leftover. Kept as-is; confirm whether it can be dropped.
    return [PyInfo(transitive_sources = depset())]
229
# Repository rule that exposes a local TensorFlow installation, configured
# through the TF_* environment variables. `environ` declares the variables
# the rule depends on; per the Bazel repository_rule docs, changing any of
# them causes the repository to be re-fetched.
tf_configure = repository_rule(
    implementation = _tf_pip_impl,
    environ = [
        _TF_HEADER_DIR,
        _TF_SHARED_LIBRARY_DIR,
        _TF_SHARED_LIBRARY_NAME,
        _TF_CXX11_ABI_FLAG,
    ],
)
239