# pytorch
# 177 lines · 5.5 KB
1import argparse
2import logging
3import os
4from functools import partial
5
6import torch
7import torch._dynamo as dynamo
8import torch.utils._pytree as pytree
9from torch._dynamo.testing import reduce_to_scalar_loss
10from torch.nn.parallel import DistributedDataParallel as DDP
11from torch.profiler import profile, ProfilerActivity, record_function
12
13
14try:
15from .common import timed
16from .dist_util import apply_fsdp, cleanup, get_model, model_iter_fn, setup
17except ImportError:
18from common import timed
19from dist_util import apply_fsdp, cleanup, get_model, model_iter_fn, setup
20
21log = logging.getLogger(__name__)
22
23
def torchviz_model(args, model, inputs, rank):
    """Render the autograd graph of one forward pass using torchviz.

    Only rank 0 writes the rendered "torchviz.dot" file; every rank still
    runs the forward pass so collective ops (DDP/FSDP) stay in sync.
    """
    from torchviz import make_dot

    # Run a forward pass so an autograd graph exists, then collapse the
    # output to a scalar the same way the benchmark's training loop does.
    out = model(*inputs)
    scalar_loss = reduce_to_scalar_loss(out)
    named_params = dict(model.named_parameters())
    graph = make_dot(
        scalar_loss, params=named_params, show_attrs=True, show_saved=True
    )
    if rank != 0:
        return
    graph.render("torchviz.dot")
33
34
def profile_model(args, model, inputs, rank):
    """Profile ``args.repeat`` forward/backward iterations of ``model``.

    Each iteration is annotated with "Forward" / "Backward" record_function
    scopes. Rank 0 exports a chrome trace to ``args.trace_file``.
    """
    activities = [ProfilerActivity.CPU, ProfilerActivity.CUDA]
    with profile(activities=activities) as prof:
        for _ in range(args.repeat):
            with record_function("Forward"):
                loss = reduce_to_scalar_loss(model(*inputs))
            with record_function("Backward"):
                loss.backward()
    if rank == 0:
        prof.export_chrome_trace(args.trace_file)
45
46
def run_model(args, model, inputs, key):
    """Initialize distributed state, optionally wrap/compile the model, time it.

    Moves the model and all tensor inputs to this rank's device, applies
    FSDP or DDP when requested, optionally compiles with dynamo, then
    returns the total wall time of ``args.repeat`` timed iterations.
    ``key`` is accepted for API compatibility but not used in the body.
    """
    rank = int(os.getenv("RANK", 0))
    world_size = int(os.getenv("WORLD_SIZE", 1))

    setup(rank, world_size)
    if args.device == "cuda":
        # needed for FSDP
        torch.cuda.set_device(rank)

    device = f"{args.device}:{rank}"
    model = model.to(device)

    def to_device(leaf):
        # Move only tensor leaves; any other pytree leaf passes through.
        return leaf.to(device) if torch.is_tensor(leaf) else leaf

    inputs = pytree.tree_map(to_device, inputs)

    if args.fsdp:
        model = apply_fsdp(
            args,
            model,
            use_checkpointing=args.fsdp_checkpoint,
            use_wrap_policy=args.fsdp_wrap,
        )
    elif args.ddp:
        model = DDP(model)

    if args.verbose:
        print(model)

    if args.dynamo:
        dynamo.reset()
        if args.verbose:
            dynamo.config.verbose = True
            dynamo.config.log_level = logging.DEBUG
        if args.dynamo_no_optimize_ddp:
            dynamo.config.optimize_ddp = False
        if args.dynamo == "inductor" and args.fsdp:
            torch._inductor.config.triton.cudagraphs = False
            log.warning("disabling inductor cudagraphs for compatibility with FSDP")

        def print_compile(gm, ex):
            print(
                f"print_compile:\n{str(gm.graph)}\n-----------------------------------------"
            )
            return gm

        backend = print_compile if args.dynamo == "print" else args.dynamo
        model = dynamo.optimize(backend)(model)

    # warmup
    _ = timed(model, model_iter_fn, inputs, times=3, return_result=False)
    t_total = timed(
        model, model_iter_fn, inputs, times=args.repeat, return_result=False
    )
    if args.torchviz:
        torchviz_model(args, model, inputs, rank)
    if args.profile:
        profile_model(args, model, inputs, rank)

    cleanup()
    return t_total
114
115
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--device", default="cuda")
    parser.add_argument(
        "--dynamo",
        default=None,
        help="if set to a str, uses dynamo[str] backend. else, eager",
    )
    parser.add_argument("--verbose", action="store_true")
    # type=int so a CLI-supplied value matches the semantics of the int-like
    # default instead of arriving as a str.
    parser.add_argument("--batch-size", "--batch_size", type=int, default=None)
    parser.add_argument(
        "--torchviz", action="store_true", help="Dump autograd graph with torchviz"
    )
    parser.add_argument("--profile", action="store_true", help="Run the profiler")
    parser.add_argument(
        "--trace-file", "--trace_file", default="profile.json", help="Run the profiler"
    )
    # type=int: without it a CLI-supplied --repeat is a str, which breaks
    # range(args.repeat) in profile_model and t_total / args.repeat below.
    parser.add_argument("--repeat", type=int, default=10, help="Repeats for timing run")
    parser.add_argument(
        "--dynamo-no-optimize-ddp",
        "--dynamo_no_optimize_ddp",
        action="store_true",
        help="Disable dynamo's ddp optimizer (enabled by default)",
    )
    parser.add_argument(
        "--fsdp-checkpoint",
        "--fsdp_checkpoint",
        action="store_true",
        help="Use gradient checkpointing via model-specific policy",
    )
    parser.add_argument(
        "--fsdp-wrap",
        "--fsdp_wrap",
        action="store_true",
        help="Apply fsdp to submodules via model-specific policy",
    )

    # DDP and FSDP are mutually exclusive wrapping strategies.
    dist_arg = parser.add_mutually_exclusive_group()
    dist_arg.add_argument("--ddp", action="store_true")
    dist_arg.add_argument("--fsdp", action="store_true")

    # Exactly one model source must be chosen.
    model_arg = parser.add_mutually_exclusive_group(required=True)
    model_arg.add_argument(
        "--torchbench-model",
        "--torchbench_model",
        help="name of torchbench model, e.g. hf_Bert",
    )
    model_arg.add_argument(
        "--toy-model", "--toy_model", action="store_true", help="use toy model instead"
    )
    args = parser.parse_args()

    model_name = args.torchbench_model
    if args.toy_model:
        model_name = "ToyModel"
    model, inputs = get_model(args)

    fn = partial(run_model, args, model, inputs)

    # WORLD_SIZE is only interpolated into the key string, so the raw
    # env value (str) or the default (int 1) are both fine here.
    world_size = os.getenv("WORLD_SIZE", 1)
    t_total = fn(f"{model_name}_{world_size}")
    print(f"mean latency {t_total / args.repeat} across {args.repeat} runs")
178