pytorch
/
pt_ops.bzl
672 строки · 19.9 Кб
1load("//tools/build_defs:expect.bzl", "expect")
2load("//tools/build_defs:fb_native_wrapper.bzl", "fb_native")
3load("//tools/build_defs:fb_xplat_genrule.bzl", "fb_xplat_genrule")
4load("//tools/build_defs:type_defs.bzl", "is_list", "is_string")
5
6# @lint-ignore BUCKRESTRICTEDSYNTAX
7IS_OSS = read_config("pt", "is_oss", "0") == "1" # True for OSS BUCK build, and False for internal BUCK build
8
9USED_PT_BACKENDS = [
10"CPU",
11"QuantizedCPU",
12"SparseCPU", # brings ~20 kb size regression
13]
14
15def pt_operator_library(
16name,
17ops = [],
18exported_deps = [],
19check_decl = True,
20train = False,
21model = None,
22include_all_operators = False,
23include_base_operators = True,
24**kwargs):
25(model_name, model_versions, model_assets, model_traced_backends) = validate_and_extract_model_information(
26name,
27model,
28)
29
30ops = [op.strip() for op in ops]
31
32# If ops are specified, then we are in static selective build mode, so we append
33# base ops to this list to avoid additional special case logic in subsequent code,
34# unless include_base_operators is explicitly set to False (the default is True)
35if len(ops) > 0 and include_base_operators:
36ops.extend(PT_BASE_OPS)
37
38labels = kwargs.pop("labels", [])
39visibility = kwargs.pop("visibility", ["PUBLIC"])
40
41# Sanity check the model name and versions. While the input to both is an array, the
42# codegen script only ever outputs a single item in the array so we can just assume that
43# here. If you ever need to depends on more than one assets, just break it up into a separate
44# BUCK targets.
45if model_assets or model_versions:
46if len(model_assets) != 1:
47fail("Model assets must be of size 1")
48if len(model_versions) != 1:
49fail("Model versions must be of size 1")
50
51# Is this a traced operator therefore has a YAML file with ops?
52yaml_option = ""
53if model_assets and len(model_assets) > 0:
54# We know these lists are only of length 1 via earlier assert.
55model_asset = model_assets[0]
56model_version = model_versions[0]
57
58# Pass the YAML file from this asset to the genrule below.
59yaml_dep = "{}_v{}_yaml".format(model_asset, model_version)
60fb_native.filegroup(
61name = yaml_dep,
62srcs = [
63model_asset + ".yaml",
64],
65# The visibility is not set to PUBLIC as this an internal detail. If you see this error
66# in your buck build flow, you are trying to use a hand-crafted "pt_operator_library" that
67# with parameters not supported outside of codegen targets!
68)
69
70# Since all selective traced ops are created by automation, we can assume they
71# have a YAML file at this very location. If it doesn't exist, it means the targets
72# was hand-crafted which is not a support workflow for traced ops.
73yaml_option = "--models_yaml_path $(location fbsource//xplat/pytorch_models/build/{}/v{}:{})/{}.yaml".format(model_name, model_version, yaml_dep, model_asset)
74
75not_include_all_overloads_static_root_ops = kwargs.pop(
76"not_include_all_overloads_static_root_ops",
77False,
78)
79
80not_include_all_overloads_closure_ops = kwargs.pop("not_include_all_overloads_closure_ops", False)
81
82if False:
83# TODO(nga): `yaml_option` is never `None`, but it is checked against `None` below.
84# Typechecker (`--unstable-typecheck`) catches it.
85yaml_option = None
86
87fb_xplat_genrule(
88name = name,
89out = "model_operators.yaml",
90cmd = (
91"$(exe {exe}) " +
92"{optionally_root_ops} " +
93"{optionally_training_root_ops} " +
94"--rule_name {rule_name} " +
95"--output_path \"${{OUT}}\" " +
96"--model_name {model_name} " +
97"--dep_graph_yaml_path {dep_graph_yaml} " +
98"{optionally_model_yamls} " +
99"{optionally_model_versions} " +
100"{optionally_model_assets} " +
101"{optionally_model_traced_backends} " +
102"{optionally_include_all_operators}" +
103"{not_include_all_overloads_static_root_ops}" +
104"{not_include_all_overloads_closure_ops}"
105).format(
106exe = "//tools:gen_operators_yaml" if IS_OSS else "fbsource//xplat/caffe2/tools:gen_operators_yaml",
107rule_name = name,
108model_name = model_name,
109dep_graph_yaml = "none" if IS_OSS else "$(location fbsource//xplat/caffe2:pytorch_op_deps)/fb/pytorch_op_deps.yaml ",
110optionally_model_yamls = "" if (IS_OSS or yaml_option == None) else yaml_option,
111optionally_root_ops = "--root_ops " + (",".join(ops)) if len(ops) > 0 else "",
112optionally_training_root_ops = "--training_root_ops " + (",".join(ops)) if len(ops) > 0 and train else "",
113optionally_model_versions = "--model_versions " + (",".join(model_versions)) if model_versions != None else "",
114optionally_model_assets = "--model_assets " + (",".join(model_assets)) if model_assets != None else "",
115optionally_model_traced_backends = "--model_traced_backends " + (",".join(model_traced_backends)) if model_traced_backends != None else "",
116optionally_include_all_operators = "--include_all_operators " if include_all_operators else "",
117not_include_all_overloads_static_root_ops = "--not_include_all_overloads_static_root_ops " if not_include_all_overloads_static_root_ops else "",
118not_include_all_overloads_closure_ops = "--not_include_all_overloads_closure_ops " if not_include_all_overloads_closure_ops else "",
119),
120labels = labels + [
121"pt_operator_library",
122"supermodule:android/default/pytorch",
123"supermodule:ios/default/public.pytorch",
124] + (["pt_train_operator_library"] if train else []),
125visibility = visibility,
126**kwargs
127)
128
129def validate_and_extract_model_information(name, model):
130model_name = name
131model_versions = None
132model_assets = None
133model_traced_backends = None
134
135if model != None:
136model_name = model.get("name")
137expect(model_name != None, "Expected Model Name to be present")
138model_versions = model.get("versions")
139expect(is_list(model_versions), "Expected model versions to be a list of string")
140for ver in model_versions or []:
141expect(is_string(ver), "Expected version '{}' to be string".format(str(ver)))
142model_assets = model.get("assets")
143expect(
144model_assets == None or is_list(model_assets),
145"Expected model assets to be a list of string if specified",
146)
147for asset_name in model_assets or []:
148expect(is_string(asset_name), "Expected asset_name '{}' to be string".format(str(asset_name)))
149model_traced_backends = model.get("traced_backends")
150expect(
151model_traced_backends == None or is_list(model_traced_backends),
152"Expected model traced backends to be a list of string if specified",
153)
154
155if model_traced_backends != None:
156for backend in model_traced_backends:
157expect(is_string(backend), "Expected backend name '{}' to be string".format(str(backend)))
158expect(
159backend in USED_PT_BACKENDS,
160"Expected backend name ({}) to be in set: {}".format(backend, ",".join(USED_PT_BACKENDS)),
161)
162
163return (model_name, model_versions, model_assets, model_traced_backends)
164
165# This file keeps a list of PyTorch operators used by any targets in
166# @fbsource//xplat/...
167# The purpose of the list is to avoid generating large number of unused
168# operator registration code / BUCK rules at build time.
169# See more detail at: https://fb.quip.com/ZVh1AgOKW8Vv
170
171PT_OPS_PRIM = [
172"aten::str",
173"aten::list",
174"aten::__range_length",
175"aten::__derive_index",
176"prim::TupleUnpack",
177"prim::unchecked_cast",
178"aten::IntImplicit",
179"aten::FloatImplicit",
180"aten::ScalarImplicit",
181"aten::Bool.Tensor",
182"aten::Bool.int",
183"aten::Bool.float",
184"aten::Int.Tensor",
185"aten::Int.Scalar",
186"aten::Int.int",
187"aten::Int.bool",
188"aten::Int.str",
189"aten::Float.Tensor",
190"aten::Float.Scalar",
191"aten::Float.int",
192"aten::Float.bool",
193"aten::Float.str",
194"aten::format",
195"prim::NumToTensor.Scalar",
196"prim::RaiseException",
197"aten::Size",
198"aten::size",
199"prim::EnumName",
200"prim::EnumValue.int",
201"prim::EnumValue.float",
202"prim::EnumValue.str",
203"prim::TupleIndex",
204"aten::ne.int_list",
205"prim::unchecked_unwrap_optional",
206"prim::device",
207"prim::dtype",
208"aten::__not__",
209"aten::__is__",
210"aten::__isnot__",
211"aten::element_size",
212"aten::numel",
213"aten::dim",
214"aten::get_device",
215"aten::storage_offset",
216"aten::is_contiguous",
217"aten::select.t",
218"aten::__getitem__.t",
219"aten::append.t",
220"aten::reverse.t",
221"aten::extend.t",
222"aten::copy.t",
223"aten::_set_item.t",
224"aten::clear.t",
225"aten::Delete.t",
226"aten::insert.t",
227"aten::pop.t",
228"aten::add.t",
229"aten::add_.t",
230"aten::slice.t",
231"aten::list.t",
232"aten::mul.left_t",
233"aten::mul.right_",
234"aten::mul_.t",
235"aten::len.t",
236"aten::eq.int_list",
237"prim::Uninitialized",
238"prim::Print",
239"aten::eq.enum",
240"aten::ne.enum",
241"aten::dequantize.tensor",
242"aten::dequantize.any",
243"aten::add.str",
244"aten::eq.int",
245"aten::eq.float",
246"aten::eq.int_float",
247"aten::eq.float_int",
248"aten::eq",
249"aten::eq.str",
250"aten::ne.int",
251"aten::ne.float",
252"aten::ne.int_float",
253"aten::ne.float_int",
254"aten::ne",
255"aten::ne.str",
256"aten::lt.int",
257"aten::lt.float",
258"aten::lt.int_float",
259"aten::lt.float_int",
260"aten::lt",
261"aten::lt.str",
262"aten::gt.int",
263"aten::gt.float",
264"aten::gt.int_float",
265"aten::gt.float_int",
266"aten::gt",
267"aten::gt.str",
268"aten::le.int",
269"aten::le.float",
270"aten::le.int_float",
271"aten::le.float_int",
272"aten::le",
273"aten::le.str",
274"aten::ge.int",
275"aten::ge.float",
276"aten::ge.int_float",
277"aten::ge.float_int",
278"aten::ge",
279"aten::ge.str",
280"aten::add.int",
281"aten::add.float",
282"aten::add.int_float",
283"aten::add.float_int",
284"aten::add",
285"aten::sub.int",
286"aten::sub.float",
287"aten::sub.int_float",
288"aten::sub.float_int",
289"aten::sub",
290"aten::mul.int",
291"aten::mul.float",
292"aten::mul.int_float",
293"aten::mul.float_int",
294"aten::mul",
295"aten::__and__.bool",
296"aten::__or__.bool",
297"aten::__xor__.bool",
298"aten::floor.int",
299"aten::floor.float",
300"aten::floor.Scalar",
301"aten::ceil.int",
302"aten::ceil.float",
303"aten::ceil.Scalar",
304"aten::neg.int",
305"aten::neg.float",
306"aten::neg.Scalar",
307"aten::exp.int",
308"aten::exp.float",
309"aten::exp.Scalar",
310"aten::remainder.int",
311"aten::remainder.float",
312"aten::remainder.int_float",
313"aten::remainder.float_int",
314"aten::remainder",
315"aten::div.int",
316"aten::div.float",
317"aten::div",
318"aten::floordiv.int",
319"aten::floordiv.float",
320"aten::floordiv.int_float",
321"aten::floordiv.float_int",
322"aten::floordiv",
323"aten::pow.int",
324"aten::pow.float",
325"aten::pow.int_float",
326"aten::pow.float_int",
327"aten::pow.Scalar_Scalar",
328"aten::pow.int_to_int",
329"prim::min.int",
330"prim::min.float",
331"prim::min.int_float",
332"prim::min.float_int",
333"prim::min",
334"prim::max.int",
335"prim::max.float",
336"prim::max.int_float",
337"prim::max.float_int",
338"prim::max",
339"prim::type",
340"aten::len.Tensor",
341"aten::ord",
342"aten::lower",
343"aten::__contains__.str_list",
344"aten::len.str",
345"aten::__getitem__.str",
346"aten::copy_.Tensor",
347"aten::copy_.int",
348"aten::copy_.float",
349"aten::backward",
350"aten::index.Tensor_hacked_twin",
351"aten::_unsafe_index.Tensor_hacked_twin",
352"aten::_index_put_impl_.hacked_twin",
353"aten::index_put_.hacked_twin",
354"aten::index_put.hacked_twin",
355"aten::_unsafe_index_put.hacked_twin",
356"aten::to.prim_Device",
357"aten::to.prim_dtype",
358"prim::is_cuda",
359"prim::data",
360"prim::min.int_list",
361"prim::max.int_list",
362"prim::min.self_int",
363"prim::max.self_int",
364"prim::min.float_list",
365"prim::max.float_list",
366"prim::min.self_float",
367"prim::max.self_float",
368"prim::min.bool_list",
369"prim::max.bool_list",
370"prim::min.self_bool",
371"prim::max.self_bool",
372"aten::len.Dict_str",
373"aten::keys.str",
374"aten::values.str",
375"aten::__getitem__.Dict_str",
376"aten::get.str",
377"aten::get.default_str",
378"aten::setdefault.str",
379"aten::Delete.Dict_str",
380"aten::pop.Dict_str",
381"aten::pop.Dict_default_str",
382"aten::popitem.str",
383"aten::clear.str",
384"aten::update.str",
385"aten::items.str",
386"aten::copy.Dict_str",
387"aten::__contains__.str",
388"aten::_set_item.str",
389"aten::dict.str",
390"aten::len.Dict_int",
391"aten::keys.int",
392"aten::values.int",
393"aten::__getitem__.Dict_int",
394"aten::get.int",
395"aten::get.default_int",
396"aten::setdefault.int",
397"aten::Delete.Dict_int",
398"aten::pop.Dict_int",
399"aten::pop.Dict_default_int",
400"aten::popitem.int",
401"aten::clear.int",
402"aten::update.int",
403"aten::items.int",
404"aten::copy.Dict_int",
405"aten::__contains__.int",
406"aten::_set_item.int",
407"aten::dict.int",
408"aten::len.Dict_bool",
409"aten::keys.bool",
410"aten::values.bool",
411"aten::__getitem__.Dict_bool",
412"aten::get.bool",
413"aten::get.default_bool",
414"aten::setdefault.bool",
415"aten::Delete.Dict_bool",
416"aten::pop.Dict_bool",
417"aten::pop.Dict_default_bool",
418"aten::popitem.bool",
419"aten::clear.bool",
420"aten::update.bool",
421"aten::items.bool",
422"aten::copy.Dict_bool",
423"aten::__contains__.bool",
424"aten::_set_item.bool",
425"aten::dict.bool",
426"aten::len.Dict_float",
427"aten::keys.float",
428"aten::values.float",
429"aten::__getitem__.Dict_float",
430"aten::get.float",
431"aten::get.default_float",
432"aten::setdefault.float",
433"aten::Delete.Dict_float",
434"aten::pop.Dict_float",
435"aten::pop.Dict_default_float",
436"aten::popitem.float",
437"aten::clear.float",
438"aten::update.float",
439"aten::items.float",
440"aten::copy.Dict_float",
441"aten::__contains__.float",
442"aten::_set_item.float",
443"aten::dict.float",
444"aten::len.Dict_Tensor",
445"aten::keys.Tensor",
446"aten::values.Tensor",
447"aten::__getitem__.Dict_Tensor",
448"aten::get.Tensor",
449"aten::get.default_Tensor",
450"aten::setdefault.Tensor",
451"aten::Delete.Dict_Tensor",
452"aten::pop.Dict_Tensor",
453"aten::pop.Dict_default_Tensor",
454"aten::popitem.Tensor",
455"aten::clear.Tensor",
456"aten::update.Tensor",
457"aten::items.Tensor",
458"aten::copy.Dict_Tensor",
459"aten::__contains__.Tensor",
460"aten::_set_item.Tensor",
461"aten::dict.Tensor",
462"aten::__round_to_zero_floordiv.int",
463"aten::mathremainder.int",
464"aten::mathremainder.float",
465"aten::mathremainder.int_float",
466"aten::mathremainder.float_int",
467"aten::mathremainder",
468"aten::__and__.int",
469"aten::__or__.int",
470"aten::__xor__.int",
471"aten::__lshift__.int",
472"aten::__rshift__.int",
473"aten::round.int",
474"aten::round.float",
475"aten::round.Scalar",
476"aten::log.int",
477"aten::log.float",
478"aten::log.Scalar",
479"aten::log.int_int",
480"aten::log.float_float",
481"aten::log.int_float",
482"aten::log.float_int",
483"aten::log.Scalar_Scalar",
484"aten::log1p.int",
485"aten::log1p.float",
486"aten::log1p.Scalar",
487"aten::log10.int",
488"aten::log10.float",
489"aten::log10.Scalar",
490"aten::sqrt.int",
491"aten::sqrt.float",
492"aten::sqrt.Scalar",
493"aten::acos.int",
494"aten::acos.float",
495"aten::acos.Scalar",
496"aten::asin.int",
497"aten::asin.float",
498"aten::asin.Scalar",
499"aten::atan.int",
500"aten::atan.float",
501"aten::atan.Scalar",
502"aten::atan2.int",
503"aten::atan2.float",
504"aten::atan2.int_float",
505"aten::atan2.float_int",
506"aten::atan2.Scalar_Scalar",
507"aten::cos.int",
508"aten::cos.float",
509"aten::cos.Scalar",
510"aten::sin.int",
511"aten::sin.float",
512"aten::sin.Scalar",
513"aten::tan.int",
514"aten::tan.float",
515"aten::tan.Scalar",
516"aten::asinh.int",
517"aten::asinh.float",
518"aten::asinh.Scalar",
519"aten::atanh.int",
520"aten::atanh.float",
521"aten::atanh.Scalar",
522"aten::acosh.int",
523"aten::acosh.float",
524"aten::acosh.Scalar",
525"aten::sinh.int",
526"aten::sinh.float",
527"aten::sinh.Scalar",
528"aten::cosh.int",
529"aten::cosh.float",
530"aten::cosh.Scalar",
531"aten::tanh.int",
532"aten::tanh.float",
533"aten::tanh.Scalar",
534"aten::degrees.int",
535"aten::degrees.float",
536"aten::degrees.Scalar",
537"aten::radians.int",
538"aten::radians.float",
539"aten::radians.Scalar",
540"aten::fmod.int",
541"aten::fmod.float",
542"aten::fmod.int_float",
543"aten::fmod.float_int",
544"aten::fmod",
545"aten::factorial.int",
546"aten::isnan.float",
547"aten::isfinite.float",
548"aten::isinf.float",
549"aten::gamma.int",
550"aten::gamma.float",
551"aten::gamma.Scalar",
552"aten::erf.int",
553"aten::erf.float",
554"aten::erf.Scalar",
555"aten::erfc.int",
556"aten::erfc.float",
557"aten::erfc.Scalar",
558"aten::expm1.int",
559"aten::expm1.float",
560"aten::expm1.Scalar",
561"aten::fabs.int",
562"aten::fabs.float",
563"aten::fabs.Scalar",
564"aten::lgamma.int",
565"aten::lgamma.float",
566"aten::lgamma.Scalar",
567"prim::abs.int",
568"prim::abs.float",
569"prim::abs.Scalar",
570"aten::gcd.int",
571"aten::copysign.int",
572"aten::copysign.float",
573"aten::copysign.int_float",
574"aten::copysign.float_int",
575"aten::copysign",
576"aten::split",
577"aten::tensor.float",
578"aten::as_tensor.float",
579"aten::tensor.int",
580"aten::as_tensor.int",
581"aten::tensor.bool",
582"aten::as_tensor.bool",
583"aten::_infer_size",
584"aten::_no_grad_embedding_renorm_",
585"aten::tensor",
586"aten::as_tensor",
587"aten::as_tensor.list",
588"aten::_pack_sequence",
589"aten::_get_tracing_state",
590"aten::is_scripting",
591"aten::_no_grad_uniform_",
592"aten::_no_grad_normal_",
593"aten::_no_grad_fill_",
594"aten::_no_grad_zero_",
595]
596
597PT_BASE_OPS = [
598"aten::_coalesced_",
599"aten::_copy_from",
600"aten::_empty_affine_quantized",
601"aten::_empty_per_channel_affine_quantized",
602"aten::_indices",
603"aten::_nnz",
604"aten::_values",
605"aten::add",
606"aten::add_",
607"aten::arange",
608"aten::as_strided",
609"aten::as_strided_",
610"aten::cat",
611"aten::clone",
612"aten::coalesce",
613"aten::contiguous",
614"aten::copy_",
615"aten::copy_sparse_to_sparse_",
616"aten::dense_dim",
617"aten::dequantize",
618"aten::div",
619"aten::div_",
620"aten::empty",
621"aten::empty_like",
622"aten::empty_strided",
623"aten::eq",
624"aten::equal",
625"aten::expand",
626"aten::fill_",
627"aten::is_coalesced",
628"aten::is_complex",
629"aten::is_floating_point",
630"aten::is_leaf",
631"aten::is_nonzero",
632"aten::item",
633"aten::max",
634"aten::min",
635"aten::mul",
636"aten::mul_",
637"aten::narrow",
638"aten::ne",
639"aten::permute",
640"aten::q_per_channel_axis",
641"aten::q_per_channel_scales",
642"aten::q_per_channel_zero_points",
643"aten::q_scale",
644"aten::q_zero_point",
645"aten::qscheme",
646"aten::quantize_per_tensor",
647"aten::reshape",
648"aten::_reshape_alias",
649"aten::resize_",
650"aten::resize_as_",
651"aten::scalar_tensor",
652"aten::select",
653"aten::set_",
654"aten::size",
655"aten::slice",
656"aten::sparse_dim",
657"aten::sparse_resize_and_clear_",
658"aten::squeeze",
659"aten::squeeze_",
660"aten::stride",
661"aten::sub",
662"aten::sub_",
663"aten::sum",
664"aten::t",
665"aten::to",
666"aten::_to_copy",
667"aten::unsqueeze",
668"aten::view",
669"aten::zero_",
670"aten::zeros",
671"aten::zeros_like",
672]
673