google-research

Форк
0
/
visualize_concept.ipynb 
261 строка · 7.9 Кб
1
{
2
  "cells": [
3
    {
4
      "cell_type": "markdown",
5
      "id": "IcwwhIrLwMZc",
6
      "metadata": {
7
        "id": "IcwwhIrLwMZc"
8
      },
9
      "source": [
10
        "Copyright 2023 Google LLC.\n",
11
        "\n",
12
        "Licensed under the Apache License, Version 2.0 (the \"License\");"
13
      ]
14
    },
15
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "258ca2e3",
      "metadata": {
        "id": "258ca2e3"
      },
      "outputs": [],
      "source": [
        "# All imports for the notebook, deduplicated (the original imported\n",
        "# StableDiffusionPipeline, LMSDiscreteScheduler, randrange and numpy twice)\n",
        "# and grouped: standard library first, then third-party packages.\n",
        "import argparse\n",
        "import glob\n",
        "import os\n",
        "from pathlib import Path\n",
        "from random import randrange\n",
        "\n",
        "import numpy as np\n",
        "import requests\n",
        "import timm\n",
        "import torch\n",
        "import torch.optim as optim\n",
        "import torchvision.transforms as transforms\n",
        "from diffusers import StableDiffusionPipeline\n",
        "from diffusers.schedulers import LMSDiscreteScheduler, DDPMScheduler, DPMSolverMultistepScheduler, PNDMScheduler\n",
        "from PIL import Image, ImageDraw"
      ]
    },
44
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "f46b8c8c",
      "metadata": {
        "id": "f46b8c8c"
      },
      "outputs": [],
      "source": [
        "# initialize stable diffusion pipeline\n",
        "pipe = StableDiffusionPipeline.from_pretrained(\"stabilityai/stable-diffusion-2-1-base\")\n",
        "pipe.to(\"cuda\")\n",
        "# swap in an LMS discrete scheduler built from the pipeline's own scheduler config\n",
        "scheduler = LMSDiscreteScheduler.from_config(pipe.scheduler.config)\n",
        "pipe.scheduler = scheduler\n",
        "# keep a frozen copy of the original token-embedding matrix; the learned\n",
        "# concept embedding w^* is composed from these rows in a later cell\n",
        "orig_embeddings = pipe.text_encoder.text_model.embeddings.token_embedding.weight.clone().detach()\n",
        "# freeze the live embedding table so nothing in this notebook trains it\n",
        "pipe.text_encoder.text_model.embeddings.token_embedding.weight.requires_grad_(False)"
      ]
    },
62
    {
63
      "cell_type": "markdown",
64
      "id": "14ba3600",
65
      "metadata": {
66
        "id": "14ba3600"
67
      },
68
      "source": [
69
        "# Load decomposition results"
70
      ]
71
    },
72
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "690fa1ba",
      "metadata": {
        "id": "690fa1ba"
      },
      "outputs": [],
      "source": [
        "# Decomposition outputs are expected under ./<concept>/ ; change `concept`\n",
        "# to visualize a different decomposition run.\n",
        "concept = 'dog'\n",
        "folder = f'./{concept}'\n",
        "# load coefficients\n",
        "# NOTE(review): torch.load unpickles arbitrary objects — only load trusted files.\n",
        "alphas_dict = torch.load(f'{folder}/best_alphas.pt').detach_().requires_grad_(False)\n",
        "# load vocabulary\n",
        "dictionary = torch.load(f'{folder}/dictionary.pt')"
      ]
    },
89
    {
90
      "cell_type": "markdown",
91
      "id": "4a77cc1a",
92
      "metadata": {
93
        "id": "4a77cc1a"
94
      },
95
      "source": [
96
        "# Visualize top coefficients and top tokens"
97
      ]
98
    },
99
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "2cef1f3c",
      "metadata": {
        "id": "2cef1f3c"
      },
      "outputs": [],
      "source": [
        "# Sort the decomposition coefficients in descending order, map the top\n",
        "# entries back to vocabulary token ids, and decode them to readable strings.\n",
        "sorted_alphas, sorted_indices = torch.sort(alphas_dict, descending=True)\n",
        "num_indices=10\n",
        "top_indices_orig_dict = [dictionary[i] for i in sorted_indices[:num_indices]]\n",
        "print(\"top coefficients: \", sorted_alphas[:num_indices].cpu().numpy())\n",
        "alpha_ids = [pipe.tokenizer.decode(idx) for idx in top_indices_orig_dict]\n",
        "print(\"top tokens: \", alpha_ids)"
      ]
    },
116
    {
117
      "cell_type": "markdown",
118
      "id": "bdac204b",
119
      "metadata": {
120
        "id": "bdac204b"
121
      },
122
      "source": [
123
        "# Extract top 50 tokens"
124
      ]
125
    },
126
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "f3dbfc77",
      "metadata": {
        "id": "f3dbfc77"
      },
      "outputs": [],
      "source": [
        "# Build a coefficient vector over the full vocabulary that keeps only the\n",
        "# num_tokens entries with the largest |alpha|; all other positions stay zero.\n",
        "num_tokens = 50\n",
        "alphas = torch.zeros(orig_embeddings.shape[0]).cuda()\n",
        "sorted_alphas, sorted_indices = torch.sort(alphas_dict.abs(), descending=True)\n",
        "top_word_idx = [dictionary[i] for i in sorted_indices[:num_tokens]]\n",
        "for i,index in enumerate(top_word_idx):\n",
        "    alphas[index] = alphas_dict[sorted_indices[i]]\n",
        "\n",
        "# add placeholder for w^*\n",
        "placeholder_token = '\u003c\u003e'\n",
        "pipe.tokenizer.add_tokens(placeholder_token)\n",
        "placeholder_token_id = pipe.tokenizer.convert_tokens_to_ids(placeholder_token)\n",
        "pipe.text_encoder.resize_token_embeddings(len(pipe.tokenizer))\n",
        "token_embeds = pipe.text_encoder.get_input_embeddings().weight.detach().requires_grad_(False)\n",
        "\n",
        "# compute w^* and normalize its embedding\n",
        "# w^* is a weighted sum of the frozen original embedding rows; it is rescaled\n",
        "# to the average row norm so it behaves like a typical token embedding.\n",
        "learned_embedding = torch.matmul(alphas, orig_embeddings).flatten()\n",
        "norms = [i.norm().item() for i in orig_embeddings]\n",
        "avg_norm = np.mean(norms)\n",
        "learned_embedding /= learned_embedding.norm()\n",
        "learned_embedding *= avg_norm\n",
        "\n",
        "# add w^* to vocabulary\n",
        "token_embeds[placeholder_token_id] = torch.nn.Parameter(learned_embedding)"
      ]
    },
160
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "b050e78f",
      "metadata": {
        "id": "b050e78f"
      },
      "outputs": [],
      "source": [
        "import math\n",
        "def get_image_grid(images) -\u003e Image:\n",
        "    \"\"\"Arrange a list of equally-sized PIL images into a near-square grid.\"\"\"\n",
        "    count = len(images)\n",
        "    n_cols = int(math.ceil(math.sqrt(count)))\n",
        "    n_rows = int(math.ceil(count / n_cols))\n",
        "    tile_w, tile_h = images[0].size\n",
        "    grid = Image.new('RGB', (n_cols * tile_w, n_rows * tile_h))\n",
        "    for idx, tile in enumerate(images):\n",
        "        # fill the grid row-major: divmod gives (row, column) for index idx\n",
        "        row, col = divmod(idx, n_cols)\n",
        "        grid.paste(tile, (col * tile_w, row * tile_h))\n",
        "    return grid"
      ]
    },
183
    {
184
      "cell_type": "markdown",
185
      "id": "8a24b91a",
186
      "metadata": {
187
        "id": "8a24b91a"
188
      },
189
      "source": [
190
        "# Reconstruction results- first 6 images of seed 0 (no cherry picking)"
191
      ]
192
    },
193
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "5ab6a7a7",
      "metadata": {
        "id": "5ab6a7a7",
        "scrolled": false
      },
      "outputs": [],
      "source": [
        "# Generate six images conditioned on the learned concept token \u003c\u003e with a\n",
        "# fixed seed so the grid is reproducible (no cherry picking).\n",
        "prompt = 'a photo of a \u003c\u003e'\n",
        "\n",
        "generator = torch.Generator(\"cuda\").manual_seed(0)\n",
        "image = pipe(prompt,\n",
        "             guidance_scale=7.5,\n",
        "             generator=generator,\n",
        "             return_dict=False,\n",
        "             num_images_per_prompt=6,\n",
        "            num_inference_steps=50)\n",
        "# with return_dict=False the pipeline returns a tuple; image[0] is the image list\n",
        "display(get_image_grid(image[0]))"
      ]
    },
215
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "dcfa03da",
      "metadata": {
        "id": "dcfa03da",
        "scrolled": false
      },
      "outputs": [],
      "source": [
        "# Baseline: same seed and sampling settings but with the literal word 'dog',\n",
        "# for side-by-side comparison against the learned token above.\n",
        "prompt = 'a photo of a dog'\n",
        "generator = torch.Generator(\"cuda\").manual_seed(0)\n",
        "image = pipe(prompt,\n",
        "             guidance_scale=7.5,\n",
        "             generator=generator,\n",
        "             return_dict=False,\n",
        "             num_images_per_prompt=6,\n",
        "            num_inference_steps=50)\n",
        "display(get_image_grid(image[0]))"
      ]
    }
236
  ],
237
  "metadata": {
238
    "colab": {
239
      "provenance": []
240
    },
241
    "kernelspec": {
242
      "display_name": "Python 3 (ipykernel)",
243
      "language": "python",
244
      "name": "python3"
245
    },
246
    "language_info": {
247
      "codemirror_mode": {
248
        "name": "ipython",
249
        "version": 3
250
      },
251
      "file_extension": ".py",
252
      "mimetype": "text/x-python",
253
      "name": "python",
254
      "nbconvert_exporter": "python",
255
      "pygments_lexer": "ipython3",
256
      "version": "3.10.9"
257
    }
258
  },
259
  "nbformat": 4,
260
  "nbformat_minor": 5
261
}
262

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.