google-research

Форк
0
/
single_image_decomposition.ipynb 
405 строк · 13.3 Кб
1
{
2
  "cells": [
3
    {
4
      "cell_type": "markdown",
5
      "id": "K5zCsXptwfL9",
6
      "metadata": {
7
        "id": "K5zCsXptwfL9"
8
      },
9
      "source": [
10
        "Copyright 2023 Google LLC.\n",
11
        "\n",
12
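        "Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "you may not use this file except in compliance with the License.\n",
        "You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0"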
13
      ]
14
    },
15
    {
16
      "cell_type": "code",
17
      "execution_count": null,
18
      "id": "e0b13011",
19
      "metadata": {
20
        "id": "e0b13011"
21
      },
22
      "outputs": [],
23
      "source": [
24
        "import shutil\n",
25
        "\n",
26
        "import numpy as np\n",
27
        "import os\n",
28
        "from PIL import Image\n",
29
        "import sys\n",
30
        "from shutil import copyfile\n",
31
        "from pathlib import Path\n",
32
        "\n",
33
        "from diffusers.schedulers import LMSDiscreteScheduler\n",
34
        "from diffusers import StableDiffusionPipeline\n",
35
        "\n",
36
        "\n",
37
        "import torch\n",
38
        "\n",
39
        "import torchvision.transforms as transforms\n",
40
        "from transformers import CLIPProcessor, CLIPModel, AutoTokenizer\n",
41
        "\n",
42
        "import glob\n",
43
        "import argparse"
44
      ]
45
    },
46
    {
47
      "cell_type": "markdown",
48
      "id": "d88f5d8f",
49
      "metadata": {
50
        "id": "d88f5d8f"
51
      },
52
      "source": [
53
        "## Choose concept and seed"
54
      ]
55
    },
56
    {
57
      "cell_type": "code",
58
      "execution_count": null,
59
      "id": "7ed9b6fe",
60
      "metadata": {
61
        "id": "7ed9b6fe"
62
      },
63
      "outputs": [],
64
      "source": [
65
        "# Concept to decompose. Underscores become spaces when the name is passed to CLIP below.\n",
        "concept = 'corn'\n",
        "# Seed shared by every image generated in this notebook, so renders are directly comparable.\n",
        "target_seed = 55\n",
        "# Folder with the decomposition outputs (output/best_alphas.pt, output/dictionary.pt).\n",
        "folder = f'./{concept}'\n",
        "# Prompt prefix; the learned '\u003c\u003e' token is appended to it. Plain string: nothing is interpolated.\n",
        "prompt = 'a photo of a '\n",
        "num_inference_steps = 25"
70
      ]
71
    },
72
    {
73
      "cell_type": "markdown",
74
      "id": "41911d50",
75
      "metadata": {
76
        "id": "41911d50"
77
      },
78
      "source": [
79
        "## Load model"
80
      ]
81
    },
82
    {
83
      "cell_type": "code",
84
      "execution_count": null,
85
      "id": "06939ef6",
86
      "metadata": {
87
        "id": "06939ef6"
88
      },
89
      "outputs": [],
90
      "source": [
91
        "# Stable Diffusion 2.1 (base) on the GPU, sampled with the LMS scheduler, progress bars off.\n",
        "pipe = StableDiffusionPipeline.from_pretrained(\"stabilityai/stable-diffusion-2-1-base\")\n",
        "pipe.to(\"cuda\")\n",
        "pipe.scheduler = LMSDiscreteScheduler.from_config(pipe.scheduler.config)\n",
        "pipe.set_progress_bar_config(disable=True)\n",
        "\n",
        "# Placeholder token whose embedding row load_alphas overwrites with the decomposition.\n",
        "placeholder_token = '\u003c\u003e'\n",
        "pipe.tokenizer.add_tokens(placeholder_token)\n",
        "trained_id = pipe.tokenizer.convert_tokens_to_ids(placeholder_token)\n",
        "# Grow the embedding table to cover the new token, then freeze it.\n",
        "pipe.text_encoder.resize_token_embeddings(len(pipe.tokenizer))\n",
        "_ = pipe.text_encoder.get_input_embeddings().weight.requires_grad_(False)\n",
        "\n",
        "# CLIP ViT-B/32: image and text features for similarity scoring.\n",
        "clip_checkpoint = \"openai/clip-vit-base-patch32\"\n",
        "clip_model = CLIPModel.from_pretrained(clip_checkpoint).to('cuda')\n",
        "clip_processor = CLIPProcessor.from_pretrained(clip_checkpoint)\n",
        "clip_tokenizer = AutoTokenizer.from_pretrained(clip_checkpoint)\n",
        "\n",
        "transform_tensor = transforms.Compose([transforms.ToTensor()])"
109
      ]
110
    },
111
    {
112
      "cell_type": "markdown",
113
      "id": "ca5eaa50",
114
      "metadata": {
115
        "id": "ca5eaa50"
116
      },
117
      "source": [
118
        "## Auxiliary functions"
119
      ]
120
    },
121
    {
122
      "cell_type": "code",
123
      "execution_count": null,
124
      "id": "c25bbe5b",
125
      "metadata": {
126
        "id": "c25bbe5b"
127
      },
128
      "outputs": [],
129
      "source": [
130
        "def clip_transform(image_tensor):\n",
        "    \"\"\"Resize images to 224x224 (bicubic) and apply CLIP's mean/std normalization.\n",
        "\n",
        "    image_tensor must be 4-D for bicubic interpolation; presumably\n",
        "    (batch, 3, H, W) with values in [0, 1] -- TODO confirm with callers.\n",
        "    \"\"\"\n",
        "    resized = torch.nn.functional.interpolate(image_tensor, size=(224, 224), mode='bicubic',\n",
        "                                              align_corners=False)\n",
        "    normalize = transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073],\n",
        "                                     std=[0.26862954, 0.26130258, 0.27577711])\n",
        "    return normalize(resized)\n",
        "\n",
        "\n",
        "def load_alphas(alphas_projection, token_embeddings, seed, prompt, target_norm=None):\n",
        "    \"\"\"Write a weighted mix of token embeddings into the '\u003c\u003e' slot and render `prompt`.\n",
        "\n",
        "    Args:\n",
        "        alphas_projection: weights, one per row of `token_embeddings`.\n",
        "        token_embeddings: embedding matrix whose rows are mixed by the weights.\n",
        "        seed: seed for the CUDA generator used for sampling.\n",
        "        prompt: text prompt; it must contain '\u003c\u003e' for the mix to have any effect.\n",
        "        target_norm: L2 norm the mixed embedding is rescaled to. Defaults to the\n",
        "            global `avg_norm` (mean norm of the original token embeddings).\n",
        "\n",
        "    Returns:\n",
        "        The first image generated by `pipe`.\n",
        "    \"\"\"\n",
        "    if target_norm is None:\n",
        "        target_norm = avg_norm\n",
        "    embedding = torch.matmul(alphas_projection, token_embeddings)\n",
        "    # Rescale to target_norm; clamp_min keeps an all-zero mix from producing NaNs.\n",
        "    embedding = embedding / embedding.norm().clamp_min(1e-12) * target_norm\n",
        "    # Write under no_grad and without nn.Parameter (which defaults to requires_grad=True)\n",
        "    # so the frozen embedding table is never pulled into the autograd graph.\n",
        "    with torch.no_grad():\n",
        "        pipe.text_encoder.text_model.embeddings.token_embedding.weight[trained_id] = embedding\n",
        "    generator = torch.Generator(\"cuda\").manual_seed(seed)\n",
        "    return pipe(prompt, guidance_scale=7.5,\n",
        "                generator=generator,\n",
        "                return_dict=False,\n",
        "                num_images_per_prompt=1,\n",
        "                num_inference_steps=num_inference_steps)[0][0]"
152
      ]
153
    },
154
    {
155
      "cell_type": "markdown",
156
      "id": "cc1e9ff6",
157
      "metadata": {
158
        "id": "cc1e9ff6"
159
      },
160
      "source": [
161
        "## Load decomposition from folder"
162
      ]
163
    },
164
    {
165
      "cell_type": "code",
166
      "execution_count": null,
167
      "id": "a5a80831",
168
      "metadata": {
169
        "id": "a5a80831"
170
      },
171
      "outputs": [],
172
      "source": [
173
        "concept_nu = concept.replace('_', ' ')\n",
        "concept_u = concept.replace(' ', '_')\n",
        "\n",
        "# Frozen copy of the vocabulary embeddings and their mean L2 norm (load_alphas rescales to avg_norm).\n",
        "# Vectorized: a per-row .item() loop forces one GPU sync per vocabulary entry.\n",
        "orig_embeddings = pipe.text_encoder.text_model.embeddings.token_embedding.weight.clone().detach()\n",
        "norms = orig_embeddings.norm(dim=1).double()\n",
        "avg_norm = norms.mean().item()\n",
        "\n",
        "# NOTE(review): torch.load unpickles; only load decomposition files you produced yourself.\n",
        "alphas_dict = torch.load(f'{folder}/output/best_alphas.pt').detach_().requires_grad_(False)\n",
        "dictionary = torch.load(f'{folder}/output/dictionary.pt')\n",
        "\n",
        "# Keep the num_alphas highest-weighted dictionary words; clamped so a short\n",
        "# dictionary cannot make the removal loop index past the end of top_word_idx.\n",
        "sorted_alphas, sorted_indices = torch.sort(alphas_dict, descending=True)\n",
        "num_alphas = min(50, len(sorted_indices))\n",
        "top_indices = sorted_indices[:num_alphas]\n",
        "top_word_idx = [dictionary[i] for i in top_indices]\n",
        "alpha_ids = [(i, pipe.tokenizer.decode([word])) for i, word in enumerate(top_word_idx)]\n",
        "\n",
        "# Dense weight vector over the whole vocabulary; only the top words are non-zero.\n",
        "alphas = torch.zeros(orig_embeddings.shape[0]).cuda()\n",
        "for word, dict_pos in zip(top_word_idx, top_indices):\n",
        "    alphas[word] = alphas_dict[dict_pos]\n",
        "\n",
        "# CLIP text features are only used for thresholding, so skip building an autograd graph.\n",
        "with torch.no_grad():\n",
        "    clip_concept_inputs = clip_tokenizer([concept_nu], padding=True, return_tensors=\"pt\").to('cuda')\n",
        "    clip_concept_features = clip_model.get_text_features(**clip_concept_inputs)\n",
        "\n",
        "    top_words = [pipe.tokenizer.decode([x]) for x in top_word_idx]\n",
        "    clip_text_inputs = clip_tokenizer(top_words, padding=True, return_tensors=\"pt\").to('cuda')\n",
        "    clip_text_features = clip_model.get_text_features(**clip_text_inputs)\n",
        "    # Pairwise cosine similarity between the top words.\n",
        "    text_norms = clip_text_features.norm(dim=1)\n",
        "    word_cosine = (torch.matmul(clip_text_features, clip_text_features.transpose(1, 0)) /\n",
        "                   torch.matmul(text_norms.unsqueeze(1), text_norms.unsqueeze(0)))\n",
        "    # Cosine similarity of each top word to the concept name.\n",
        "    concept_words_similarity = torch.cosine_similarity(clip_concept_features, clip_text_features, dim=1)\n",
        "\n",
        "# Top words that are near-synonyms of the concept name itself.\n",
        "similar_words = (concept_words_similarity.cpu().numpy() \u003e 0.92).nonzero()[0]\n",
        "# clip_words_similarity[i, j] is True when top words i and j are near-duplicates.\n",
        "clip_words_similarity = word_cosine.cpu().numpy() \u003e 0.95\n",
        "\n",
        "# Zero-out words that are near-synonyms of the concept.\n",
        "for i in similar_words:\n",
        "    alphas[top_word_idx[i]] = 0"
209
      ]
210
    },
211
    {
212
      "cell_type": "markdown",
213
      "id": "5813ed86",
214
      "metadata": {
215
        "id": "5813ed86"
216
      },
217
      "source": [
218
        "### Visualize ground truth concept image"
219
      ]
220
    },
221
    {
222
      "cell_type": "code",
223
      "execution_count": null,
224
      "id": "5748dcb1",
225
      "metadata": {
226
        "id": "5748dcb1",
227
        "scrolled": false
228
      },
229
      "outputs": [],
230
      "source": [
231
        "# Ground-truth render: the plain concept prompt with the shared seed.\n",
        "# NOTE(review): this uses the raw `concept` (underscores kept) while the CLIP features above use concept_nu -- confirm intended.\n",
        "orig_prompt = f'a photo of a {concept}'\n",
        "generator = torch.Generator(\"cuda\").manual_seed(target_seed)\n",
        "orig_outputs = pipe(orig_prompt,\n",
        "                    guidance_scale=7.5,\n",
        "                    generator=generator,\n",
        "                    return_dict=False,\n",
        "                    num_images_per_prompt=1,\n",
        "                    num_inference_steps=num_inference_steps)\n",
        "orig_image = orig_outputs[0][0]\n",
        "orig_image.resize((224, 224))"
238
      ]
239
    },
240
    {
241
      "cell_type": "markdown",
242
      "id": "bf62ce03",
243
      "metadata": {
244
        "id": "bf62ce03"
245
      },
246
      "source": [
247
        "### Visualize decomposition image"
248
      ]
249
    },
250
    {
251
      "cell_type": "code",
252
      "execution_count": null,
253
      "id": "e405e39e",
254
      "metadata": {
255
        "id": "e405e39e"
256
      },
257
      "outputs": [],
258
      "source": [
259
        "# Render the full decomposition; `image` is the target the removal loop below scores against.\n",
        "decomp_prompt = f'{prompt} \u003c\u003e'\n",
        "image = load_alphas(alphas, orig_embeddings, target_seed, decomp_prompt)\n",
        "image.resize((224, 224))"
261
      ]
262
    },
263
    {
264
      "cell_type": "markdown",
265
      "id": "de48de44",
266
      "metadata": {
267
        "id": "de48de44"
268
      },
269
      "source": [
270
        "## Single-image decomposition code"
271
      ]
272
    },
273
    {
274
      "cell_type": "markdown",
275
      "id": "29066be2",
276
      "metadata": {
277
        "id": "29066be2"
278
      },
279
      "source": [
280
        "### Iteratively remove features from the decomposition"
281
      ]
282
    },
283
    {
284
      "cell_type": "code",
285
      "execution_count": null,
286
      "id": "34945bd7",
287
      "metadata": {
288
        "id": "34945bd7",
289
        "scrolled": true
290
      },
291
      "outputs": [],
292
      "source": [
293
        "with torch.no_grad():\n",
        "    final_alphas = alphas.clone()\n",
        "    # CLIP features of the full decomposition image (`image`, rendered above); every\n",
        "    # candidate removal is scored against this target.\n",
        "    target_clip = clip_processor(images=image, return_tensors=\"pt\")['pixel_values'].cuda()\n",
        "    target_clip = clip_model.get_image_features(target_clip)\n",
        "    # A removal is kept only if the re-rendered image stays above this cosine similarity.\n",
        "    removal_threshold = 0.93\n",
        "    # Visit the top words from lowest to highest weight.\n",
        "    indices = np.arange(num_alphas)[::-1]\n",
        "\n",
        "    removed = True\n",
        "    while removed:\n",
        "        removed = False\n",
        "        next_indices = []\n",
        "        for idx in indices:\n",
        "            candidate_alphas = final_alphas.clone()\n",
        "            candidate_alphas[top_word_idx[idx]] = 0\n",
        "            # Also remove words CLIP rates as near-duplicates of this one.\n",
        "            for similar_idx in clip_words_similarity[idx].nonzero()[0]:\n",
        "                candidate_alphas[top_word_idx[similar_idx]] = 0\n",
        "            if torch.equal(candidate_alphas, final_alphas):\n",
        "                # Nothing left to zero: rendering would just reproduce the current state,\n",
        "                # so drop the index without spending a diffusion run on it.\n",
        "                print(\"token already removed in idx: \", idx)\n",
        "                continue\n",
        "            # Use a separate name: overwriting the global `image` made a re-run of this\n",
        "            # cell score against the last candidate instead of the decomposition image.\n",
        "            candidate_image = load_alphas(candidate_alphas, orig_embeddings, target_seed, f'{prompt} \u003c\u003e')\n",
        "\n",
        "            curr_clip = clip_processor(images=candidate_image, return_tensors=\"pt\")['pixel_values'].cuda()\n",
        "            curr_clip = clip_model.get_image_features(curr_clip)\n",
        "            similarity = torch.cosine_similarity(target_clip, curr_clip).item()\n",
        "            if similarity \u003e removal_threshold:\n",
        "                print(\"removing token in idx: \", idx)\n",
        "                final_alphas = candidate_alphas\n",
        "                removed = True\n",
        "            else:\n",
        "                print(f\"similarity: {similarity} keeping token in idx: \", idx)\n",
        "                next_indices.append(idx)\n",
        "        indices = next_indices"
324
      ]
325
    },
326
    {
327
      "cell_type": "markdown",
328
      "id": "87e5fe1a",
329
      "metadata": {
330
        "id": "87e5fe1a"
331
      },
332
      "source": [
333
        "### Visualize image after removing features"
334
      ]
335
    },
336
    {
337
      "cell_type": "code",
338
      "execution_count": null,
339
      "id": "f74f42e6",
340
      "metadata": {
341
        "id": "f74f42e6"
342
      },
343
      "outputs": [],
344
      "source": [
345
        "# Render the pruned decomposition with the shared seed, for comparison with the images above.\n",
        "image_decomp = load_alphas(alphas_projection=final_alphas,\n",
        "                           token_embeddings=orig_embeddings,\n",
        "                           seed=target_seed,\n",
        "                           prompt=f'{prompt} \u003c\u003e')\n",
        "image_decomp.resize((224, 224))"
347
      ]
348
    },
349
    {
350
      "cell_type": "markdown",
351
      "id": "6c922333",
352
      "metadata": {
353
        "id": "6c922333"
354
      },
355
      "source": [
356
        "### Visualize the remaining image features"
357
      ]
358
    },
359
    {
360
      "cell_type": "code",
361
      "execution_count": null,
362
      "id": "02ee962b",
363
      "metadata": {
364
        "id": "02ee962b"
365
      },
366
      "outputs": [],
367
      "source": [
368
        "# Render each surviving word on its own (plain prompt, shared seed) to show what it contributes.\n",
        "remaining_features = torch.nonzero(final_alphas).flatten()\n",
        "for feature in remaining_features:\n",
        "    feature_word = pipe.tokenizer.decode(feature)\n",
        "    print(\"feature: \", feature_word)\n",
        "    generator = torch.Generator(\"cuda\").manual_seed(target_seed)\n",
        "    feature_visualization = pipe(f'a photo of a {feature_word}',\n",
        "                                 guidance_scale=7.5,\n",
        "                                 generator=generator,\n",
        "                                 return_dict=False,\n",
        "                                 num_images_per_prompt=1,\n",
        "                                 num_inference_steps=num_inference_steps)[0][0]\n",
        "    display(feature_visualization.resize((224, 224)))"
378
      ]
379
    }
380
  ],
381
  "metadata": {
382
    "colab": {
383
      "provenance": []
384
    },
385
    "kernelspec": {
386
      "display_name": "Python 3 (ipykernel)",
387
      "language": "python",
388
      "name": "python3"
389
    },
390
    "language_info": {
391
      "codemirror_mode": {
392
        "name": "ipython",
393
        "version": 3
394
      },
395
      "file_extension": ".py",
396
      "mimetype": "text/x-python",
397
      "name": "python",
398
      "nbconvert_exporter": "python",
399
      "pygments_lexer": "ipython3",
400
      "version": "3.10.9"
401
    }
402
  },
403
  "nbformat": 4,
404
  "nbformat_minor": 5
405
}
406

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.