# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Model definition."""
# pylint: disable=g-multiple-import,g-importing-member,g-bad-import-order,missing-class-docstring

from typing import Optional
from PIL import Image
import torch
from torchvision import transforms
from transformers import CLIPModel, CLIPProcessor


class ModelZoo:
  # Common interface: subclasses implement preprocessing (`transform`,
  # `transform_tensor`) and CLIP-space scoring (`calculate_loss`,
  # `get_probability`).

  def transform(self, image):
    pass

  def transform_tensor(self, image_tensor):
    pass

  def calculate_loss(self, output, target_images):
    pass

  def get_probability(self, output, target_images):
    pass


class CLIPImageSimilarity(ModelZoo):

  def __init__(self):
    # Load the pretrained CLIP model and its preprocessor onto the GPU.
    self.clip_model = CLIPModel.from_pretrained(
        "openai/clip-vit-base-patch32"
    ).to("cuda")
    self.clip_processor = CLIPProcessor.from_pretrained(
        "openai/clip-vit-base-patch32"
    )

  def transform(self, image):
    # Preprocess a PIL image (or list of images) into normalized pixel
    # tensors on the GPU.
    images_processed = self.clip_processor(images=image, return_tensors="pt")[
        "pixel_values"
    ].cuda()
    return images_processed

  def transform_tensor(self, image_tensor):
    # Resize to CLIP's 224x224 input resolution and apply CLIP's channel
    # normalization; both operations are differentiable, so gradients can
    # flow back to the input tensor.
    image_tensor = torch.nn.functional.interpolate(
        image_tensor, size=(224, 224), mode="bicubic", align_corners=False
    )
    normalize = transforms.Normalize(
        mean=[0.48145466, 0.4578275, 0.40821073],
        std=[0.26862954, 0.26130258, 0.27577711],
    )
    image_tensor = normalize(image_tensor)
    return image_tensor

  def calculate_loss(self, output, target_images):
    # Embed the candidate pixels with CLIP's image encoder.
    output = self.clip_model.get_image_features(output)
    # loss = -torch.cosine_similarity(output, input_clip_embedding, axis=1)

    # One minus the mean cosine similarity between each candidate embedding
    # and the mean target embedding: lower loss means closer to the targets.
    mean_target_image = target_images.mean(dim=0).reshape(1, -1)
    loss = torch.mean(
        torch.cosine_similarity(
            output[None], mean_target_image[:, None], axis=2
        ),
        axis=1,
    )
    loss = 1 - loss.mean()
    return loss

  def get_probability(self, output, target_images):
    # Mean cosine similarity between the candidate embeddings and the mean
    # target embedding; this is a similarity score rather than a calibrated
    # probability.
    output = self.clip_model.get_image_features(output)
    mean_target_image = target_images.mean(dim=0).reshape(1, -1)
    similarity = torch.mean(
        torch.cosine_similarity(output[None], mean_target_image, axis=2),
        axis=1,
    )
    return similarity.mean()
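

# A minimal usage sketch, not part of the original module: it assumes a CUDA
# device and uses a synthetic gray target image purely for illustration.
if __name__ == "__main__":
  model = CLIPImageSimilarity()
  # Precompute CLIP embeddings for the target image(s); `calculate_loss`
  # and `get_probability` expect target embeddings, not raw pixels.
  target_pils = [Image.new("RGB", (256, 256), color=(128, 128, 128))]
  with torch.no_grad():
    target_embeddings = model.clip_model.get_image_features(
        model.transform(target_pils)
    )
  # Optimize a random candidate image toward the mean target embedding.
  candidate = torch.rand(1, 3, 256, 256, device="cuda", requires_grad=True)
  optimizer = torch.optim.Adam([candidate], lr=0.05)
  for _ in range(100):
    optimizer.zero_grad()
    loss = model.calculate_loss(
        model.transform_tensor(candidate), target_embeddings
    )
    loss.backward()
    optimizer.step()
  score = model.get_probability(
      model.transform_tensor(candidate), target_embeddings
  )
  print(f"mean cosine similarity to targets: {score.item():.4f}")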