lama
31 строка · 972.0 Байт
1import logging
2
3from saicinpainting.training.modules.ffc import FFCResNetGenerator
4from saicinpainting.training.modules.pix2pixhd import GlobalGenerator, MultiDilatedGlobalGenerator, \
5NLayerDiscriminator, MultidilatedNLayerDiscriminator
6
def make_generator(config, kind, **kwargs):
    """Build the generator network selected by ``kind``.

    Args:
        config: Not read by this function. NOTE(review): presumably kept so
            existing call sites that pass it keep working — confirm before
            removing it.
        kind: Name of the generator architecture; one of
            ``'pix2pixhd_multidilated'``, ``'pix2pixhd_global'`` or
            ``'ffc_resnet'``.
        **kwargs: Forwarded unchanged to the selected generator's constructor.

    Returns:
        The newly constructed generator instance.

    Raises:
        ValueError: If ``kind`` is not a supported generator kind. The message
            starts with ``Unknown generator kind <kind>`` and lists the
            supported kinds.
    """
    # Lazy %-style args: the message is only formatted if INFO is enabled.
    logging.info('Make generator %s', kind)

    # Zero-argument builders: a class is only looked up and instantiated for
    # the kind that was actually requested.
    builders = {
        'pix2pixhd_multidilated': lambda: MultiDilatedGlobalGenerator(**kwargs),
        'pix2pixhd_global': lambda: GlobalGenerator(**kwargs),
        'ffc_resnet': lambda: FFCResNetGenerator(**kwargs),
    }

    build = builders.get(kind)
    if build is None:
        raise ValueError(
            f'Unknown generator kind {kind}; '
            f'supported kinds: {", ".join(sorted(builders))}'
        )
    return build()
20
21
def make_discriminator(kind, **kwargs):
    """Build the discriminator network selected by ``kind``.

    Args:
        kind: Name of the discriminator architecture; one of
            ``'pix2pixhd_nlayer_multidilated'`` or ``'pix2pixhd_nlayer'``.
        **kwargs: Forwarded unchanged to the selected discriminator's
            constructor.

    Returns:
        The newly constructed discriminator instance.

    Raises:
        ValueError: If ``kind`` is not a supported discriminator kind. The
            message starts with ``Unknown discriminator kind <kind>`` and
            lists the supported kinds.
    """
    # Lazy %-style args: the message is only formatted if INFO is enabled.
    logging.info('Make discriminator %s', kind)

    # Zero-argument builders: a class is only looked up and instantiated for
    # the kind that was actually requested.
    builders = {
        'pix2pixhd_nlayer_multidilated': lambda: MultidilatedNLayerDiscriminator(**kwargs),
        'pix2pixhd_nlayer': lambda: NLayerDiscriminator(**kwargs),
    }

    build = builders.get(kind)
    if build is None:
        raise ValueError(
            f'Unknown discriminator kind {kind}; '
            f'supported kinds: {", ".join(sorted(builders))}'
        )
    return build()
32