lama

Форк
0
47 строк · 2.1 Кб
1
import torch
2
from kornia.constants import SamplePadding
3
from kornia.augmentation import RandomAffine, CenterCrop
4

5

6
class FakeFakesGenerator:
7
    def __init__(self, aug_proba=0.5, img_aug_degree=30, img_aug_translate=0.2):
8
        self.grad_aug = RandomAffine(degrees=360,
9
                                     translate=0.2,
10
                                     padding_mode=SamplePadding.REFLECTION,
11
                                     keepdim=False,
12
                                     p=1)
13
        self.img_aug = RandomAffine(degrees=img_aug_degree,
14
                                    translate=img_aug_translate,
15
                                    padding_mode=SamplePadding.REFLECTION,
16
                                    keepdim=True,
17
                                    p=1)
18
        self.aug_proba = aug_proba
19

20
    def __call__(self, input_images, masks):
21
        blend_masks = self._fill_masks_with_gradient(masks)
22
        blend_target = self._make_blend_target(input_images)
23
        result = input_images * (1 - blend_masks) + blend_target * blend_masks
24
        return result, blend_masks
25

26
    def _make_blend_target(self, input_images):
27
        batch_size = input_images.shape[0]
28
        permuted = input_images[torch.randperm(batch_size)]
29
        augmented = self.img_aug(input_images)
30
        is_aug = (torch.rand(batch_size, device=input_images.device)[:, None, None, None] < self.aug_proba).float()
31
        result = augmented * is_aug + permuted * (1 - is_aug)
32
        return result
33

34
    def _fill_masks_with_gradient(self, masks):
35
        batch_size, _, height, width = masks.shape
36
        grad = torch.linspace(0, 1, steps=width * 2, device=masks.device, dtype=masks.dtype) \
37
            .view(1, 1, 1, -1).expand(batch_size, 1, height * 2, width * 2)
38
        grad = self.grad_aug(grad)
39
        grad = CenterCrop((height, width))(grad)
40
        grad *= masks
41

42
        grad_for_min = grad + (1 - masks) * 10
43
        grad -= grad_for_min.view(batch_size, -1).min(-1).values[:, None, None, None]
44
        grad /= grad.view(batch_size, -1).max(-1).values[:, None, None, None] + 1e-6
45
        grad.clamp_(min=0, max=1)
46

47
        return grad
48

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.