# lama
# Original listing: 47 lines, 2.1 KB
import torch
from kornia.augmentation import CenterCrop, RandomAffine
from kornia.constants import SamplePadding


class FakeFakesGenerator:
    """Create "fake fakes": images whose masked region is blended with other content.

    For every mask, a soft blend mask is built by filling the masked region
    with a randomly rotated/shifted linear ramp rescaled to [0, 1].  The
    masked region of each image is then blended with a per-sample target that
    is either an affine-augmented copy of the batch (with probability
    ``aug_proba``) or an image taken from a random permutation of the batch.

    Args:
        aug_proba: probability that a sample's blend target comes from
            ``self.img_aug`` instead of the permuted batch.
        img_aug_degree: ``degrees`` argument of the image ``RandomAffine``.
        img_aug_translate: ``translate`` argument of the image ``RandomAffine``.
    """

    def __init__(self, aug_proba=0.5, img_aug_degree=30, img_aug_translate=0.2):
        # Random affine applied to the ramp so the blend direction and offset
        # differ per sample (rotation range up to +/-360 degrees).
        self.grad_aug = RandomAffine(
            degrees=360,
            translate=0.2,
            padding_mode=SamplePadding.REFLECTION,
            keepdim=False,
            p=1,
        )
        # Random affine used to derive a blend target from the images themselves.
        self.img_aug = RandomAffine(
            degrees=img_aug_degree,
            translate=img_aug_translate,
            padding_mode=SamplePadding.REFLECTION,
            keepdim=True,
            p=1,
        )
        self.aug_proba = aug_proba

    def __call__(self, input_images, masks):
        """Blend a random target into the masked region of each image.

        Args:
            input_images: image batch; presumably ``(B, C, H, W)`` — TODO confirm.
            masks: 4-D mask batch ``(B, 1, H, W)``; non-zero where content is
                replaced (presumably binary — TODO confirm).

        Returns:
            ``(result, blend_masks)``: the blended images and the soft blend
            masks that were used (values clamped to [0, 1]).
        """
        blend_masks = self._fill_masks_with_gradient(masks)
        blend_target = self._make_blend_target(input_images)
        result = input_images * (1 - blend_masks) + blend_target * blend_masks
        return result, blend_masks

    def _make_blend_target(self, input_images):
        """Choose, per sample, an augmented image or a randomly permuted one.

        With probability ``self.aug_proba`` sample ``i`` takes
        ``self.img_aug(input_images)[i]``; otherwise it takes the image at
        position ``i`` of a random batch permutation (possibly itself).
        """
        batch_size = input_images.shape[0]
        permuted = input_images[torch.randperm(batch_size)]
        augmented = self.img_aug(input_images)
        use_aug = torch.rand(batch_size, device=input_images.device) < self.aug_proba
        # Select rather than mix arithmetically: keeps the input dtype (a
        # float32 0/1 weight would upcast half-precision images) and a
        # non-finite value in the unselected branch cannot become NaN via 0 * inf.
        return torch.where(use_aug[:, None, None, None], augmented, permuted)

    def _fill_masks_with_gradient(self, masks):
        """Fill each mask with a randomly oriented linear ramp rescaled to [0, 1].

        A horizontal 0 -> 1 ramp of twice the mask size is transformed by
        ``self.grad_aug``, center-cropped back to ``(H, W)`` and normalized
        within the mask by ``_normalize_gradient_in_mask``.
        """
        batch_size, _, height, width = masks.shape
        # The 2x canvas presumably keeps the central crop away from the
        # reflection-padded borders after rotation — TODO confirm.
        ramp = torch.linspace(0, 1, steps=width * 2, device=masks.device, dtype=masks.dtype)
        grad = ramp.view(1, 1, 1, -1).expand(batch_size, 1, height * 2, width * 2)
        grad = self.grad_aug(grad)
        grad = CenterCrop((height, width))(grad)
        return self._normalize_gradient_in_mask(grad, masks)

    @staticmethod
    def _normalize_gradient_in_mask(grad, masks, eps=1e-6):
        """Restrict ``grad`` to ``masks`` and rescale it per sample to [0, 1].

        The per-sample minimum is taken over masked pixels only: unmasked
        pixels get +10 so they cannot be the minimum (assumes ``grad`` values
        are well below 10; the ramp lies in [0, 1]).  After subtracting it,
        unmasked pixels are <= 0 for non-negative ``grad`` and clamp to 0.

        A sample with an empty mask yields all zeros.  Without clamping the
        maximum at 0, its minimum would be the sentinel 10, every pixel would
        become -10, and dividing by the negative maximum would produce a mask
        of ones, blending the entire image.
        """
        grad = grad * masks
        grad_for_min = grad + (1 - masks) * 10
        # flatten(1) works whether or not the tensor is contiguous; view() does not.
        min_val = grad_for_min.flatten(1).min(dim=1).values.view(-1, 1, 1, 1)
        grad = grad - min_val
        max_val = grad.flatten(1).max(dim=1).values.view(-1, 1, 1, 1)
        # clamp(min=0) is a no-op whenever the sample has any fully masked pixel.
        grad = grad / (max_val.clamp(min=0) + eps)
        return grad.clamp(min=0, max=1)