google-research

Форк
0
71 строка · 2.2 Кб
1
# coding=utf-8
2
# Copyright 2024 The Google Research Authors.
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
#     http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15

16
"""miscellaneous utils."""
17
from collections import defaultdict  # pylint: disable=g-importing-member
18
import json
19
import os
20

21
import torch
22

23

24
def find_best_checkpoint(
25
    ckpt_dir, start_from=None, end_on=None, metric_name='fid5k_full'
26
):
27
  """find checkpoint with best metric value and return path."""
28
  # based on stylegan training-runs outputs
29
  metric_file = os.path.join(ckpt_dir, f'metric-{metric_name}.jsonl')
30
  fids = []
31
  with open(metric_file) as f:
32
    for line in f:
33
      fids.append((json.loads(line.strip())))
34
  metric = []
35
  for item in fids:
36
    metric.append((item['results'][metric_name], item['snapshot_pkl']))
37
  if start_from is not None:
38
    metric = metric[start_from:]
39
  if end_on is not None:
40
    metric = metric[:end_on]
41
  ckpt_metric = min(metric)
42
  print('best checkpoint:')
43
  print(ckpt_metric)
44
  ckpt_path = os.path.join(ckpt_dir, ckpt_metric[1])
45
  print(ckpt_path)
46
  print('final checkpoint: %s' % metric[-1][1])
47
  print('final checkpoint idx: %s' % len(metric))
48
  return ckpt_path
49

50

51
def interpolate(x, size, mode='bilinear'):
52
  out = torch.nn.functional.interpolate(
53
      x, size, mode=mode, align_corners=False, antialias=True
54
  )
55
  return out
56

57

58
def concat_dict(input_list, dim=1):
59
  # input: list of dictionaries
60
  # output: dictionary with values concatenated from input list
61
  output_dict = defaultdict(list)
62
  for item in input_list:
63
    for k, v in item.items():
64
      output_dict[k].append(v)
65
  return {k: torch.cat(v, dim=dim) for k, v in output_dict.items()}
66

67

68
def count_parameters(model, all_params=False):
69
  return sum(
70
      p.numel() for p in model.parameters() if p.requires_grad or all_params
71
  )
72

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.