google-research

Форк
0
/
model_zoo.py 
98 строк · 2.7 Кб
1
# coding=utf-8
2
# Copyright 2024 The Google Research Authors.
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
#     http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15

16
"""Model definition."""
17
# pylint: disable=g-multiple-import,g-importing-member,g-bad-import-order,missing-class-docstring
18

19
from typing import Optional
20
from PIL import Image
21
import torch
22
from torchvision import transforms
23
from transformers import CLIPModel, CLIPProcessor
24

25

26
class ModelZoo:
27

28
  def transform(self, image):
29
    pass
30

31
  def transform_tensor(self, image_tensor):
32
    pass
33

34
  def calculate_loss(
35
      self, output, target_images
36
  ):
37
    pass
38

39
  def get_probability(
40
      self, output, target_images
41
  ):
42
    pass
43

44

45
class CLIPImageSimilarity(ModelZoo):
46

47
  def __init__(self):
48
    # initialize classifier
49
    self.clip_model = CLIPModel.from_pretrained(
50
        "openai/clip-vit-base-patch32"
51
    ).to("cuda")
52
    self.clip_processor = CLIPProcessor.from_pretrained(
53
        "openai/clip-vit-base-patch32"
54
    )
55

56
  def transform(self, image):
57
    images_processed = self.clip_processor(images=image, return_tensors="pt")[
58
        "pixel_values"
59
    ].cuda()
60
    return images_processed
61

62
  def transform_tensor(self, image_tensor):
63
    image_tensor = torch.nn.functional.interpolate(
64
        image_tensor, size=(224, 224), mode="bicubic", align_corners=False
65
    )
66
    normalize = transforms.Normalize(
67
        mean=[0.48145466, 0.4578275, 0.40821073],
68
        std=[0.26862954, 0.26130258, 0.27577711],
69
    )
70
    image_tensor = normalize(image_tensor)
71
    return image_tensor
72

73
  def calculate_loss(
74
      self, output, target_images
75
  ):
76
    # calculate CLIP loss
77
    output = self.clip_model.get_image_features(output)
78
    # loss = -torch.cosine_similarity(output, input_clip_embedding, axis=1)
79

80
    mean_target_image = target_images.mean(dim=0).reshape(1, -1)
81
    loss = torch.mean(
82
        torch.cosine_similarity(
83
            output[None], mean_target_image[:, None], axis=2
84
        ),
85
        axis=1,
86
    )
87
    loss = 1 - loss.mean()
88
    return loss
89

90
  def get_probability(
91
      self, output, target_images
92
  ):
93
    output = self.clip_model.get_image_features(output)
94
    mean_target_image = target_images.mean(dim=0).reshape(1, -1)
95
    loss = torch.mean(
96
        torch.cosine_similarity(output[None], mean_target_image, axis=2), axis=1
97
    )
98
    return loss.mean()
99

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.