google-research
71 строка · 2.2 Кб
1# coding=utf-8
2# Copyright 2024 The Google Research Authors.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8# http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15
16"""miscellaneous utils."""
from collections import defaultdict  # pylint: disable=g-importing-member
import json
import math
import os

import torch
22
23
def find_best_checkpoint(
    ckpt_dir, start_from=None, end_on=None, metric_name='fid5k_full'
):
  """Find the checkpoint with the lowest metric value and return its path.

  Reads `metric-{metric_name}.jsonl` from `ckpt_dir` (the layout produced by
  StyleGAN training runs). Each non-blank line is a JSON record. For each
  record, `record['results'][metric_name]` is the metric value and
  `record['snapshot_pkl']` is the snapshot filename relative to `ckpt_dir`.
  Lower metric values are treated as better. NaN values are ranked worse
  than any real value. Ties are broken by snapshot filename.

  Args:
    ckpt_dir: directory holding the metric file and the snapshots.
    start_from: if not None, drop records before this position
      (`records[start_from:]`).
    end_on: if not None, keep only the first `end_on` records of what is
      left. NOTE: this is applied AFTER `start_from`, so it is relative to
      the already-trimmed list, not an absolute index.
    metric_name: metric key; also selects the `.jsonl` file name.

  Returns:
    `os.path.join(ckpt_dir, snapshot_pkl)` of the best record.

  Raises:
    ValueError: if no records remain after applying `start_from`/`end_on`.
  """
  # based on stylegan training-runs outputs
  metric_file = os.path.join(ckpt_dir, f'metric-{metric_name}.jsonl')
  records = []
  with open(metric_file, encoding='utf-8') as f:
    for line in f:
      line = line.strip()
      if line:  # json.loads('') raises; tolerate stray blank lines.
        records.append(json.loads(line))
  metric = [
      (item['results'][metric_name], item['snapshot_pkl']) for item in records
  ]
  if start_from is not None:
    metric = metric[start_from:]
  if end_on is not None:
    metric = metric[:end_on]
  if not metric:
    raise ValueError(
        f'No {metric_name} entries left in {metric_file} '
        f'(start_from={start_from}, end_on={end_on}).'
    )

  def _rank(entry):
    value, snapshot = entry
    # NaN compares False against everything, so a plain tuple min() can
    # return a NaN entry as "best"; rank NaN as +inf instead.
    return (math.inf if math.isnan(value) else value, snapshot)

  ckpt_metric = min(metric, key=_rank)
  print('best checkpoint:')
  print(ckpt_metric)
  ckpt_path = os.path.join(ckpt_dir, ckpt_metric[1])
  print(ckpt_path)
  print('final checkpoint: %s' % metric[-1][1])
  print('final checkpoint idx: %s' % len(metric))
  return ckpt_path
49
50
def interpolate(x, size, mode='bilinear'):
  """Resize `x` to `size` via torch.nn.functional.interpolate.

  Anti-aliasing is enabled for the modes that support it ('bilinear' and
  'bicubic'). `align_corners=False` is passed for the interpolating modes.
  PyTorch rejects `align_corners` for 'nearest', 'nearest-exact' and 'area',
  and rejects `antialias` outside bilinear/bicubic. Those options are
  therefore omitted for such modes instead of raising.

  Args:
    x: input tensor in the layout torch.nn.functional.interpolate expects
      (batch, channels, spatial...).
    size: output spatial size.
    mode: interpolation mode name.

  Returns:
    The resized tensor.
  """
  options = {}
  if mode not in ('nearest', 'nearest-exact', 'area'):
    options['align_corners'] = False
  if mode in ('bilinear', 'bicubic'):
    options['antialias'] = True
  return torch.nn.functional.interpolate(x, size, mode=mode, **options)
56
57
def concat_dict(input_list, dim=1):
  """Merge a list of dictionaries by concatenating their values key-wise.

  Values sharing a key are gathered in input order and joined with
  `torch.cat` along `dim`. Keys appear in the result in the order they were
  first seen.

  Args:
    input_list: iterable of dictionaries whose values can be passed to
      `torch.cat`.
    dim: dimension along which the values are concatenated.

  Returns:
    A dict mapping each key to the concatenation of all its values.
  """
  grouped = {}
  for entry in input_list:
    for key, value in entry.items():
      grouped.setdefault(key, []).append(value)
  return {key: torch.cat(values, dim=dim) for key, values in grouped.items()}
66
67
def count_parameters(model, all_params=False):
  """Count the scalar entries across a model's parameters.

  Args:
    model: object exposing `.parameters()` (e.g. a torch.nn.Module).
    all_params: when True, count every parameter; otherwise count only
      parameters with `requires_grad` set.

  Returns:
    Total number of elements (`numel()`) of the selected parameters.
  """
  total = 0
  for param in model.parameters():
    if all_params or param.requires_grad:
      total += param.numel()
  return total
72