pytorch
/
pt_ops.bzl
671 строка · 19.9 Кб
1load("//tools/build_defs:expect.bzl", "expect")
2load("//tools/build_defs:fb_native_wrapper.bzl", "fb_native")
3load("//tools/build_defs:fb_xplat_genrule.bzl", "fb_xplat_genrule")
4load("//tools/build_defs:type_defs.bzl", "is_list", "is_string")
5
6IS_OSS = read_config("pt", "is_oss", "0") == "1" # True for OSS BUCK build, and False for internal BUCK build
7
8USED_PT_BACKENDS = [
9"CPU",
10"QuantizedCPU",
11"SparseCPU", # brings ~20 kb size regression
12]
13
14def pt_operator_library(
15name,
16ops = [],
17exported_deps = [],
18check_decl = True,
19train = False,
20model = None,
21include_all_operators = False,
22include_base_operators = True,
23**kwargs):
24(model_name, model_versions, model_assets, model_traced_backends) = validate_and_extract_model_information(
25name,
26model,
27)
28
29ops = [op.strip() for op in ops]
30
31# If ops are specified, then we are in static selective build mode, so we append
32# base ops to this list to avoid additional special case logic in subsequent code,
33# unless include_base_operators is explicitly set to False (the default is True)
34if len(ops) > 0 and include_base_operators:
35ops.extend(PT_BASE_OPS)
36
37labels = kwargs.pop("labels", [])
38visibility = kwargs.pop("visibility", ["PUBLIC"])
39
40# Sanity check the model name and versions. While the input to both is an array, the
41# codegen script only ever outputs a single item in the array so we can just assume that
42# here. If you ever need to depends on more than one assets, just break it up into a separate
43# BUCK targets.
44if model_assets or model_versions:
45if len(model_assets) != 1:
46fail("Model assets must be of size 1")
47if len(model_versions) != 1:
48fail("Model versions must be of size 1")
49
50# Is this a traced operator therefore has a YAML file with ops?
51yaml_option = ""
52if model_assets and len(model_assets) > 0:
53# We know these lists are only of length 1 via earlier assert.
54model_asset = model_assets[0]
55model_version = model_versions[0]
56
57# Pass the YAML file from this asset to the genrule below.
58yaml_dep = "{}_v{}_yaml".format(model_asset, model_version)
59fb_native.filegroup(
60name = yaml_dep,
61srcs = [
62model_asset + ".yaml",
63],
64# The visibility is not set to PUBLIC as this an internal detail. If you see this error
65# in your buck build flow, you are trying to use a hand-crafted "pt_operator_library" that
66# with parameters not supported outside of codegen targets!
67)
68
69# Since all selective traced ops are created by automation, we can assume they
70# have a YAML file at this very location. If it doesn't exist, it means the targets
71# was hand-crafted which is not a support workflow for traced ops.
72yaml_option = "--models_yaml_path $(location fbsource//xplat/pytorch_models/build/{}/v{}:{})/{}.yaml".format(model_name, model_version, yaml_dep, model_asset)
73
74not_include_all_overloads_static_root_ops = kwargs.pop(
75"not_include_all_overloads_static_root_ops",
76False,
77)
78
79not_include_all_overloads_closure_ops = kwargs.pop("not_include_all_overloads_closure_ops", False)
80
81if False:
82# TODO(nga): `yaml_option` is never `None`, but it is checked against `None` below.
83# Typechecker (`--unstable-typecheck`) catches it.
84yaml_option = None
85
86fb_xplat_genrule(
87name = name,
88out = "model_operators.yaml",
89cmd = (
90"$(exe {exe}) " +
91"{optionally_root_ops} " +
92"{optionally_training_root_ops} " +
93"--rule_name {rule_name} " +
94"--output_path \"${{OUT}}\" " +
95"--model_name {model_name} " +
96"--dep_graph_yaml_path {dep_graph_yaml} " +
97"{optionally_model_yamls} " +
98"{optionally_model_versions} " +
99"{optionally_model_assets} " +
100"{optionally_model_traced_backends} " +
101"{optionally_include_all_operators}" +
102"{not_include_all_overloads_static_root_ops}" +
103"{not_include_all_overloads_closure_ops}"
104).format(
105exe = "//tools:gen_operators_yaml" if IS_OSS else "fbsource//xplat/caffe2/tools:gen_operators_yaml",
106rule_name = name,
107model_name = model_name,
108dep_graph_yaml = "none" if IS_OSS else "$(location fbsource//xplat/caffe2:pytorch_op_deps)/fb/pytorch_op_deps.yaml ",
109optionally_model_yamls = "" if (IS_OSS or yaml_option == None) else yaml_option,
110optionally_root_ops = "--root_ops " + (",".join(ops)) if len(ops) > 0 else "",
111optionally_training_root_ops = "--training_root_ops " + (",".join(ops)) if len(ops) > 0 and train else "",
112optionally_model_versions = "--model_versions " + (",".join(model_versions)) if model_versions != None else "",
113optionally_model_assets = "--model_assets " + (",".join(model_assets)) if model_assets != None else "",
114optionally_model_traced_backends = "--model_traced_backends " + (",".join(model_traced_backends)) if model_traced_backends != None else "",
115optionally_include_all_operators = "--include_all_operators " if include_all_operators else "",
116not_include_all_overloads_static_root_ops = "--not_include_all_overloads_static_root_ops " if not_include_all_overloads_static_root_ops else "",
117not_include_all_overloads_closure_ops = "--not_include_all_overloads_closure_ops " if not_include_all_overloads_closure_ops else "",
118),
119labels = labels + [
120"pt_operator_library",
121"supermodule:android/default/pytorch",
122"supermodule:ios/default/public.pytorch",
123] + (["pt_train_operator_library"] if train else []),
124visibility = visibility,
125**kwargs
126)
127
128def validate_and_extract_model_information(name, model):
129model_name = name
130model_versions = None
131model_assets = None
132model_traced_backends = None
133
134if model != None:
135model_name = model.get("name")
136expect(model_name != None, "Expected Model Name to be present")
137model_versions = model.get("versions")
138expect(is_list(model_versions), "Expected model versions to be a list of string")
139for ver in model_versions or []:
140expect(is_string(ver), "Expected version '{}' to be string".format(str(ver)))
141model_assets = model.get("assets")
142expect(
143model_assets == None or is_list(model_assets),
144"Expected model assets to be a list of string if specified",
145)
146for asset_name in model_assets or []:
147expect(is_string(asset_name), "Expected asset_name '{}' to be string".format(str(asset_name)))
148model_traced_backends = model.get("traced_backends")
149expect(
150model_traced_backends == None or is_list(model_traced_backends),
151"Expected model traced backends to be a list of string if specified",
152)
153
154if model_traced_backends != None:
155for backend in model_traced_backends:
156expect(is_string(backend), "Expected backend name '{}' to be string".format(str(backend)))
157expect(
158backend in USED_PT_BACKENDS,
159"Expected backend name ({}) to be in set: {}".format(backend, ",".join(USED_PT_BACKENDS)),
160)
161
162return (model_name, model_versions, model_assets, model_traced_backends)
163
164# This file keeps a list of PyTorch operators used by any targets in
165# @fbsource//xplat/...
166# The purpose of the list is to avoid generating large number of unused
167# operator registration code / BUCK rules at build time.
168# See more detail at: https://fb.quip.com/ZVh1AgOKW8Vv
169
170PT_OPS_PRIM = [
171"aten::str",
172"aten::list",
173"aten::__range_length",
174"aten::__derive_index",
175"prim::TupleUnpack",
176"prim::unchecked_cast",
177"aten::IntImplicit",
178"aten::FloatImplicit",
179"aten::ScalarImplicit",
180"aten::Bool.Tensor",
181"aten::Bool.int",
182"aten::Bool.float",
183"aten::Int.Tensor",
184"aten::Int.Scalar",
185"aten::Int.int",
186"aten::Int.bool",
187"aten::Int.str",
188"aten::Float.Tensor",
189"aten::Float.Scalar",
190"aten::Float.int",
191"aten::Float.bool",
192"aten::Float.str",
193"aten::format",
194"prim::NumToTensor.Scalar",
195"prim::RaiseException",
196"aten::Size",
197"aten::size",
198"prim::EnumName",
199"prim::EnumValue.int",
200"prim::EnumValue.float",
201"prim::EnumValue.str",
202"prim::TupleIndex",
203"aten::ne.int_list",
204"prim::unchecked_unwrap_optional",
205"prim::device",
206"prim::dtype",
207"aten::__not__",
208"aten::__is__",
209"aten::__isnot__",
210"aten::element_size",
211"aten::numel",
212"aten::dim",
213"aten::get_device",
214"aten::storage_offset",
215"aten::is_contiguous",
216"aten::select.t",
217"aten::__getitem__.t",
218"aten::append.t",
219"aten::reverse.t",
220"aten::extend.t",
221"aten::copy.t",
222"aten::_set_item.t",
223"aten::clear.t",
224"aten::Delete.t",
225"aten::insert.t",
226"aten::pop.t",
227"aten::add.t",
228"aten::add_.t",
229"aten::slice.t",
230"aten::list.t",
231"aten::mul.left_t",
232"aten::mul.right_",
233"aten::mul_.t",
234"aten::len.t",
235"aten::eq.int_list",
236"prim::Uninitialized",
237"prim::Print",
238"aten::eq.enum",
239"aten::ne.enum",
240"aten::dequantize.tensor",
241"aten::dequantize.any",
242"aten::add.str",
243"aten::eq.int",
244"aten::eq.float",
245"aten::eq.int_float",
246"aten::eq.float_int",
247"aten::eq",
248"aten::eq.str",
249"aten::ne.int",
250"aten::ne.float",
251"aten::ne.int_float",
252"aten::ne.float_int",
253"aten::ne",
254"aten::ne.str",
255"aten::lt.int",
256"aten::lt.float",
257"aten::lt.int_float",
258"aten::lt.float_int",
259"aten::lt",
260"aten::lt.str",
261"aten::gt.int",
262"aten::gt.float",
263"aten::gt.int_float",
264"aten::gt.float_int",
265"aten::gt",
266"aten::gt.str",
267"aten::le.int",
268"aten::le.float",
269"aten::le.int_float",
270"aten::le.float_int",
271"aten::le",
272"aten::le.str",
273"aten::ge.int",
274"aten::ge.float",
275"aten::ge.int_float",
276"aten::ge.float_int",
277"aten::ge",
278"aten::ge.str",
279"aten::add.int",
280"aten::add.float",
281"aten::add.int_float",
282"aten::add.float_int",
283"aten::add",
284"aten::sub.int",
285"aten::sub.float",
286"aten::sub.int_float",
287"aten::sub.float_int",
288"aten::sub",
289"aten::mul.int",
290"aten::mul.float",
291"aten::mul.int_float",
292"aten::mul.float_int",
293"aten::mul",
294"aten::__and__.bool",
295"aten::__or__.bool",
296"aten::__xor__.bool",
297"aten::floor.int",
298"aten::floor.float",
299"aten::floor.Scalar",
300"aten::ceil.int",
301"aten::ceil.float",
302"aten::ceil.Scalar",
303"aten::neg.int",
304"aten::neg.float",
305"aten::neg.Scalar",
306"aten::exp.int",
307"aten::exp.float",
308"aten::exp.Scalar",
309"aten::remainder.int",
310"aten::remainder.float",
311"aten::remainder.int_float",
312"aten::remainder.float_int",
313"aten::remainder",
314"aten::div.int",
315"aten::div.float",
316"aten::div",
317"aten::floordiv.int",
318"aten::floordiv.float",
319"aten::floordiv.int_float",
320"aten::floordiv.float_int",
321"aten::floordiv",
322"aten::pow.int",
323"aten::pow.float",
324"aten::pow.int_float",
325"aten::pow.float_int",
326"aten::pow.Scalar_Scalar",
327"aten::pow.int_to_int",
328"prim::min.int",
329"prim::min.float",
330"prim::min.int_float",
331"prim::min.float_int",
332"prim::min",
333"prim::max.int",
334"prim::max.float",
335"prim::max.int_float",
336"prim::max.float_int",
337"prim::max",
338"prim::type",
339"aten::len.Tensor",
340"aten::ord",
341"aten::lower",
342"aten::__contains__.str_list",
343"aten::len.str",
344"aten::__getitem__.str",
345"aten::copy_.Tensor",
346"aten::copy_.int",
347"aten::copy_.float",
348"aten::backward",
349"aten::index.Tensor_hacked_twin",
350"aten::_unsafe_index.Tensor_hacked_twin",
351"aten::_index_put_impl_.hacked_twin",
352"aten::index_put_.hacked_twin",
353"aten::index_put.hacked_twin",
354"aten::_unsafe_index_put.hacked_twin",
355"aten::to.prim_Device",
356"aten::to.prim_dtype",
357"prim::is_cuda",
358"prim::data",
359"prim::min.int_list",
360"prim::max.int_list",
361"prim::min.self_int",
362"prim::max.self_int",
363"prim::min.float_list",
364"prim::max.float_list",
365"prim::min.self_float",
366"prim::max.self_float",
367"prim::min.bool_list",
368"prim::max.bool_list",
369"prim::min.self_bool",
370"prim::max.self_bool",
371"aten::len.Dict_str",
372"aten::keys.str",
373"aten::values.str",
374"aten::__getitem__.Dict_str",
375"aten::get.str",
376"aten::get.default_str",
377"aten::setdefault.str",
378"aten::Delete.Dict_str",
379"aten::pop.Dict_str",
380"aten::pop.Dict_default_str",
381"aten::popitem.str",
382"aten::clear.str",
383"aten::update.str",
384"aten::items.str",
385"aten::copy.Dict_str",
386"aten::__contains__.str",
387"aten::_set_item.str",
388"aten::dict.str",
389"aten::len.Dict_int",
390"aten::keys.int",
391"aten::values.int",
392"aten::__getitem__.Dict_int",
393"aten::get.int",
394"aten::get.default_int",
395"aten::setdefault.int",
396"aten::Delete.Dict_int",
397"aten::pop.Dict_int",
398"aten::pop.Dict_default_int",
399"aten::popitem.int",
400"aten::clear.int",
401"aten::update.int",
402"aten::items.int",
403"aten::copy.Dict_int",
404"aten::__contains__.int",
405"aten::_set_item.int",
406"aten::dict.int",
407"aten::len.Dict_bool",
408"aten::keys.bool",
409"aten::values.bool",
410"aten::__getitem__.Dict_bool",
411"aten::get.bool",
412"aten::get.default_bool",
413"aten::setdefault.bool",
414"aten::Delete.Dict_bool",
415"aten::pop.Dict_bool",
416"aten::pop.Dict_default_bool",
417"aten::popitem.bool",
418"aten::clear.bool",
419"aten::update.bool",
420"aten::items.bool",
421"aten::copy.Dict_bool",
422"aten::__contains__.bool",
423"aten::_set_item.bool",
424"aten::dict.bool",
425"aten::len.Dict_float",
426"aten::keys.float",
427"aten::values.float",
428"aten::__getitem__.Dict_float",
429"aten::get.float",
430"aten::get.default_float",
431"aten::setdefault.float",
432"aten::Delete.Dict_float",
433"aten::pop.Dict_float",
434"aten::pop.Dict_default_float",
435"aten::popitem.float",
436"aten::clear.float",
437"aten::update.float",
438"aten::items.float",
439"aten::copy.Dict_float",
440"aten::__contains__.float",
441"aten::_set_item.float",
442"aten::dict.float",
443"aten::len.Dict_Tensor",
444"aten::keys.Tensor",
445"aten::values.Tensor",
446"aten::__getitem__.Dict_Tensor",
447"aten::get.Tensor",
448"aten::get.default_Tensor",
449"aten::setdefault.Tensor",
450"aten::Delete.Dict_Tensor",
451"aten::pop.Dict_Tensor",
452"aten::pop.Dict_default_Tensor",
453"aten::popitem.Tensor",
454"aten::clear.Tensor",
455"aten::update.Tensor",
456"aten::items.Tensor",
457"aten::copy.Dict_Tensor",
458"aten::__contains__.Tensor",
459"aten::_set_item.Tensor",
460"aten::dict.Tensor",
461"aten::__round_to_zero_floordiv.int",
462"aten::mathremainder.int",
463"aten::mathremainder.float",
464"aten::mathremainder.int_float",
465"aten::mathremainder.float_int",
466"aten::mathremainder",
467"aten::__and__.int",
468"aten::__or__.int",
469"aten::__xor__.int",
470"aten::__lshift__.int",
471"aten::__rshift__.int",
472"aten::round.int",
473"aten::round.float",
474"aten::round.Scalar",
475"aten::log.int",
476"aten::log.float",
477"aten::log.Scalar",
478"aten::log.int_int",
479"aten::log.float_float",
480"aten::log.int_float",
481"aten::log.float_int",
482"aten::log.Scalar_Scalar",
483"aten::log1p.int",
484"aten::log1p.float",
485"aten::log1p.Scalar",
486"aten::log10.int",
487"aten::log10.float",
488"aten::log10.Scalar",
489"aten::sqrt.int",
490"aten::sqrt.float",
491"aten::sqrt.Scalar",
492"aten::acos.int",
493"aten::acos.float",
494"aten::acos.Scalar",
495"aten::asin.int",
496"aten::asin.float",
497"aten::asin.Scalar",
498"aten::atan.int",
499"aten::atan.float",
500"aten::atan.Scalar",
501"aten::atan2.int",
502"aten::atan2.float",
503"aten::atan2.int_float",
504"aten::atan2.float_int",
505"aten::atan2.Scalar_Scalar",
506"aten::cos.int",
507"aten::cos.float",
508"aten::cos.Scalar",
509"aten::sin.int",
510"aten::sin.float",
511"aten::sin.Scalar",
512"aten::tan.int",
513"aten::tan.float",
514"aten::tan.Scalar",
515"aten::asinh.int",
516"aten::asinh.float",
517"aten::asinh.Scalar",
518"aten::atanh.int",
519"aten::atanh.float",
520"aten::atanh.Scalar",
521"aten::acosh.int",
522"aten::acosh.float",
523"aten::acosh.Scalar",
524"aten::sinh.int",
525"aten::sinh.float",
526"aten::sinh.Scalar",
527"aten::cosh.int",
528"aten::cosh.float",
529"aten::cosh.Scalar",
530"aten::tanh.int",
531"aten::tanh.float",
532"aten::tanh.Scalar",
533"aten::degrees.int",
534"aten::degrees.float",
535"aten::degrees.Scalar",
536"aten::radians.int",
537"aten::radians.float",
538"aten::radians.Scalar",
539"aten::fmod.int",
540"aten::fmod.float",
541"aten::fmod.int_float",
542"aten::fmod.float_int",
543"aten::fmod",
544"aten::factorial.int",
545"aten::isnan.float",
546"aten::isfinite.float",
547"aten::isinf.float",
548"aten::gamma.int",
549"aten::gamma.float",
550"aten::gamma.Scalar",
551"aten::erf.int",
552"aten::erf.float",
553"aten::erf.Scalar",
554"aten::erfc.int",
555"aten::erfc.float",
556"aten::erfc.Scalar",
557"aten::expm1.int",
558"aten::expm1.float",
559"aten::expm1.Scalar",
560"aten::fabs.int",
561"aten::fabs.float",
562"aten::fabs.Scalar",
563"aten::lgamma.int",
564"aten::lgamma.float",
565"aten::lgamma.Scalar",
566"prim::abs.int",
567"prim::abs.float",
568"prim::abs.Scalar",
569"aten::gcd.int",
570"aten::copysign.int",
571"aten::copysign.float",
572"aten::copysign.int_float",
573"aten::copysign.float_int",
574"aten::copysign",
575"aten::split",
576"aten::tensor.float",
577"aten::as_tensor.float",
578"aten::tensor.int",
579"aten::as_tensor.int",
580"aten::tensor.bool",
581"aten::as_tensor.bool",
582"aten::_infer_size",
583"aten::_no_grad_embedding_renorm_",
584"aten::tensor",
585"aten::as_tensor",
586"aten::as_tensor.list",
587"aten::_pack_sequence",
588"aten::_get_tracing_state",
589"aten::is_scripting",
590"aten::_no_grad_uniform_",
591"aten::_no_grad_normal_",
592"aten::_no_grad_fill_",
593"aten::_no_grad_zero_",
594]
595
596PT_BASE_OPS = [
597"aten::_coalesced_",
598"aten::_copy_from",
599"aten::_empty_affine_quantized",
600"aten::_empty_per_channel_affine_quantized",
601"aten::_indices",
602"aten::_nnz",
603"aten::_values",
604"aten::add",
605"aten::add_",
606"aten::arange",
607"aten::as_strided",
608"aten::as_strided_",
609"aten::cat",
610"aten::clone",
611"aten::coalesce",
612"aten::contiguous",
613"aten::copy_",
614"aten::copy_sparse_to_sparse_",
615"aten::dense_dim",
616"aten::dequantize",
617"aten::div",
618"aten::div_",
619"aten::empty",
620"aten::empty_like",
621"aten::empty_strided",
622"aten::eq",
623"aten::equal",
624"aten::expand",
625"aten::fill_",
626"aten::is_coalesced",
627"aten::is_complex",
628"aten::is_floating_point",
629"aten::is_leaf",
630"aten::is_nonzero",
631"aten::item",
632"aten::max",
633"aten::min",
634"aten::mul",
635"aten::mul_",
636"aten::narrow",
637"aten::ne",
638"aten::permute",
639"aten::q_per_channel_axis",
640"aten::q_per_channel_scales",
641"aten::q_per_channel_zero_points",
642"aten::q_scale",
643"aten::q_zero_point",
644"aten::qscheme",
645"aten::quantize_per_tensor",
646"aten::reshape",
647"aten::_reshape_alias",
648"aten::resize_",
649"aten::resize_as_",
650"aten::scalar_tensor",
651"aten::select",
652"aten::set_",
653"aten::size",
654"aten::slice",
655"aten::sparse_dim",
656"aten::sparse_resize_and_clear_",
657"aten::squeeze",
658"aten::squeeze_",
659"aten::stride",
660"aten::sub",
661"aten::sub_",
662"aten::sum",
663"aten::t",
664"aten::to",
665"aten::_to_copy",
666"aten::unsqueeze",
667"aten::view",
668"aten::zero_",
669"aten::zeros",
670"aten::zeros_like",
671]
672