skypilot

Форк
0
/
callback.patch 
35 строк · 1.1 Кб
1
diff --git a/train.py b/train.py
2
index 6e3b058..8c61ed4 100755
3
--- a/train.py
4
+++ b/train.py
5
@@ -58,6 +58,9 @@ try:
6
 except ImportError: 
7
     has_wandb = False
8
 
9
+import sky_callback
10
+from sky_callback import step_iterator
11
+
12
 torch.backends.cudnn.benchmark = True
13
 _logger = logging.getLogger('train')
14
 
15
@@ -609,6 +612,11 @@ def main():
16
         with open(os.path.join(output_dir, 'args.yaml'), 'w') as f:
17
             f.write(args_text)
18
 
19
+    sky_callback.init(
20
+        global_rank=args.rank,
21
+        total_steps=num_epochs * len(loader_train),
22
+    )
23
+
24
     try:
25
         for epoch in range(start_epoch, num_epochs):
26
             if args.distributed and hasattr(loader_train.sampler, 'set_epoch'):
27
@@ -674,7 +682,7 @@ def train_one_epoch(
28
     end = time.time()
29
     last_idx = len(loader) - 1
30
     num_updates = epoch * len(loader)
31
-    for batch_idx, (input, target) in enumerate(loader):
32
+    for batch_idx, (input, target) in step_iterator(enumerate(loader)):
33
         last_batch = batch_idx == last_idx
34
         data_time_m.update(time.time() - end)
35
         if not args.prefetcher:
36

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.