skypilot
1diff --git a/train.py b/train.py
2index 6e3b058..8c61ed4 100755
3--- a/train.py
4+++ b/train.py
5@@ -58,6 +58,9 @@ try:
6except ImportError:
7has_wandb = False
8
9+import sky_callback
10+from sky_callback import step_iterator
11+
12torch.backends.cudnn.benchmark = True
13_logger = logging.getLogger('train')
14
15@@ -609,6 +612,11 @@ def main():
16with open(os.path.join(output_dir, 'args.yaml'), 'w') as f:
17f.write(args_text)
18
19+ sky_callback.init(
20+ global_rank=args.rank,
21+ total_steps=num_epochs * len(loader_train),
22+ )
23+
24try:
25for epoch in range(start_epoch, num_epochs):
26if args.distributed and hasattr(loader_train.sampler, 'set_epoch'):
27@@ -674,7 +682,7 @@ def train_one_epoch(
28end = time.time()
29last_idx = len(loader) - 1
30num_updates = epoch * len(loader)
31- for batch_idx, (input, target) in enumerate(loader):
32+ for batch_idx, (input, target) in step_iterator(enumerate(loader)):
33last_batch = batch_idx == last_idx
34data_time_m.update(time.time() - end)
35if not args.prefetcher:
36