{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "K5zCsXptwfL9",
   "metadata": {
    "id": "K5zCsXptwfL9"
   },
   "source": [
    "Copyright 2023 Google LLC.\n",
    "\n",
    "Licensed under the Apache License, Version 2.0 (the \"License\");"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e0b13011",
   "metadata": {
    "id": "e0b13011"
   },
   "outputs": [],
   "source": [
    "import shutil\n",
    "\n",
    "import numpy as np\n",
    "import os\n",
    "from PIL import Image\n",
    "import sys\n",
    "from shutil import copyfile\n",
    "from pathlib import Path\n",
    "\n",
    "from diffusers.schedulers import LMSDiscreteScheduler\n",
    "from diffusers import StableDiffusionPipeline\n",
    "\n",
    "\n",
    "import torch\n",
    "\n",
    "import torchvision.transforms as transforms\n",
    "from transformers import CLIPProcessor, CLIPModel, AutoTokenizer\n",
    "\n",
    "import glob\n",
    "import argparse"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d88f5d8f",
   "metadata": {
    "id": "d88f5d8f"
   },
   "source": [
    "## Choose concept and seed"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7ed9b6fe",
   "metadata": {
    "id": "7ed9b6fe"
   },
   "outputs": [],
   "source": [
    "concept = 'corn'\n",
    "target_seed = 55\n",
    "folder = f'./{concept}'\n",
    "prompt = 'a photo of a '\n",
    "num_inference_steps = 25"
   ]
  },
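  {
   "cell_type": "markdown",
   "id": "f1a2b3c4",
   "metadata": {},
   "source": [
    "Optional sanity check (a minimal sketch, not required for the rest of the notebook): the decomposition cells below load `best_alphas.pt` and `dictionary.pt` from `{folder}/output/`, so this only confirms those files are present for the chosen concept."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f1a2b3c5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Optional check: the decomposition artifacts loaded later in this notebook.\n",
    "for name in ['best_alphas.pt', 'dictionary.pt']:\n",
    "    print(name, 'found:', Path(f'{folder}/output/{name}').exists())"
   ]
  },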
  {
   "cell_type": "markdown",
   "id": "41911d50",
   "metadata": {
    "id": "41911d50"
   },
   "source": [
    "## Load model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "06939ef6",
   "metadata": {
    "id": "06939ef6"
   },
   "outputs": [],
   "source": [
    "pipe = StableDiffusionPipeline.from_pretrained(\"stabilityai/stable-diffusion-2-1-base\")\n",
    "pipe.to(\"cuda\")\n",
    "pipe.scheduler = LMSDiscreteScheduler.from_config(pipe.scheduler.config)\n",
    "pipe.set_progress_bar_config(disable=True)\n",
    "\n",
    "# Register the placeholder token '<>' whose embedding will hold the reconstructed concept.\n",
    "pipe.tokenizer.add_tokens('<>')\n",
    "trained_id = pipe.tokenizer.convert_tokens_to_ids('<>')\n",
    "pipe.text_encoder.resize_token_embeddings(len(pipe.tokenizer))\n",
    "_ = pipe.text_encoder.get_input_embeddings().weight.requires_grad_(False)\n",
    "\n",
    "\n",
    "clip_model = CLIPModel.from_pretrained(\"openai/clip-vit-base-patch32\").to('cuda')\n",
    "clip_processor = CLIPProcessor.from_pretrained(\"openai/clip-vit-base-patch32\")\n",
    "\n",
    "clip_tokenizer = AutoTokenizer.from_pretrained(\"openai/clip-vit-base-patch32\")\n",
    "\n",
    "transform_tensor = transforms.Compose([\n",
    "    transforms.ToTensor(),\n",
    "])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ca5eaa50",
   "metadata": {
    "id": "ca5eaa50"
   },
   "source": [
    "## Auxiliary functions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c25bbe5b",
   "metadata": {
    "id": "c25bbe5b"
   },
   "outputs": [],
   "source": [
    "def clip_transform(image_tensor):\n",
    "    # Resize to CLIP's 224x224 input resolution and apply CLIP's normalization statistics.\n",
    "    image_tensor = torch.nn.functional.interpolate(image_tensor, size=(224, 224), mode='bicubic',\n",
    "                                                    align_corners=False)\n",
    "    normalize = transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073],\n",
    "                                     std=[0.26862954, 0.26130258, 0.27577711])\n",
    "    image_tensor = normalize(image_tensor)\n",
    "    return image_tensor\n",
    "\n",
    "\n",
    "def load_alphas(alphas_projection, token_embeddings, seed, prompt):\n",
    "    # Form the concept embedding as a weighted sum of token embeddings, rescale it to the\n",
    "    # average token-embedding norm (avg_norm, computed in the decomposition-loading cell below),\n",
    "    # write it into the placeholder token slot, and generate an image for the given prompt and seed.\n",
    "    alphas_copy = alphas_projection.clone()\n",
    "    embedding = torch.matmul(alphas_copy, token_embeddings)\n",
    "    embedding = torch.mul(embedding, 1 / embedding.norm())\n",
    "    embedding = torch.mul(embedding, avg_norm)\n",
    "    pipe.text_encoder.text_model.embeddings.token_embedding.weight[trained_id] = torch.nn.Parameter(\n",
    "        embedding)\n",
    "    generator = torch.Generator(\"cuda\").manual_seed(seed)\n",
    "    return pipe(prompt, guidance_scale=7.5,\n",
    "                generator=generator,\n",
    "                return_dict=False,\n",
    "                num_images_per_prompt=1,\n",
    "                num_inference_steps=num_inference_steps)[0][0]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cc1e9ff6",
   "metadata": {
    "id": "cc1e9ff6"
   },
   "source": [
    "# Load decomposition from folder"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a5a80831",
   "metadata": {
    "id": "a5a80831"
   },
   "outputs": [],
   "source": [
    "concept_nu = concept.replace('_', ' ')\n",
    "concept_u = concept.replace(' ', '_')\n",
    "\n",
    "# Average norm of the original token embeddings; used to rescale the reconstructed concept embedding.\n",
    "orig_embeddings = pipe.text_encoder.text_model.embeddings.token_embedding.weight.clone().detach()\n",
    "norms = [i.norm().item() for i in orig_embeddings]\n",
    "avg_norm = np.mean(norms)\n",
    "\n",
    "# Learned sparse coefficients and the token dictionary produced by the decomposition stage.\n",
    "alphas_dict = torch.load(f'{folder}/output/best_alphas.pt').detach_().requires_grad_(False)\n",
    "dictionary = torch.load(f'{folder}/output/dictionary.pt')\n",
    "\n",
    "# Keep the top-50 dictionary tokens and scatter their coefficients into a full-vocabulary vector.\n",
    "sorted_alphas, sorted_indices = torch.sort(alphas_dict, descending=True)\n",
    "alpha_ids = []\n",
    "num_alphas = 50\n",
    "for i, idx in enumerate(sorted_indices[:num_alphas]):\n",
    "    alpha_ids.append((i, pipe.tokenizer.decode([dictionary[idx]])))\n",
    "alphas = torch.zeros(orig_embeddings.shape[0]).cuda()\n",
    "top_word_idx = [dictionary[i] for i in sorted_indices[:num_alphas]]\n",
    "for i, index in enumerate(top_word_idx):\n",
    "    alphas[index] = alphas_dict[sorted_indices[i]]\n",
    "\n",
    "# CLIP text features for the concept word and for each of the top tokens.\n",
    "clip_concept_inputs = clip_tokenizer([concept_nu], padding=True, return_tensors=\"pt\").to('cuda')\n",
    "clip_concept_features = clip_model.get_text_features(**clip_concept_inputs)\n",
    "\n",
    "clip_text_inputs = clip_tokenizer([pipe.tokenizer.decode([x]) for x in top_word_idx], padding=True, return_tensors=\"pt\").to('cuda')\n",
    "clip_text_features = clip_model.get_text_features(**clip_text_inputs)\n",
    "\n",
    "# Pairwise similarity between the top tokens (used later to drop near-duplicate tokens together).\n",
    "clip_words_similarity = (torch.matmul(clip_text_features, clip_text_features.transpose(1, 0)) /\n",
    "                         torch.matmul(clip_text_features.norm(dim=1).unsqueeze(1),\n",
    "                                      clip_text_features.norm(dim=1).unsqueeze(0)))\n",
    "\n",
    "concept_words_similarity = torch.cosine_similarity(clip_concept_features, clip_text_features, dim=1)\n",
    "similar_words = (np.array(concept_words_similarity.detach().cpu()) > 0.92).nonzero()[0]\n",
    "clip_words_similarity = (np.array(clip_words_similarity.detach().cpu()) > 0.95)\n",
    "\n",
    "# Zero out tokens that are too similar to the concept word itself.\n",
    "for i in similar_words:\n",
    "    alphas[top_word_idx[i]] = 0"
   ]
  },
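  {
   "cell_type": "markdown",
   "id": "b2c3d4e5",
   "metadata": {},
   "source": [
    "Optional: print the top dictionary tokens and their learned coefficients (a minimal sketch that only reads `alpha_ids` and `sorted_alphas` from the cell above)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b2c3d4e6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Inspect the strongest components of the decomposition.\n",
    "for rank, word in alpha_ids[:10]:\n",
    "    print(f'{rank:2d}  {word:<20s}  {sorted_alphas[rank].item():.4f}')"
   ]
  },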
  {
   "cell_type": "markdown",
   "id": "5813ed86",
   "metadata": {
    "id": "5813ed86"
   },
   "source": [
    "### Visualize ground truth concept image"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5748dcb1",
   "metadata": {
    "id": "5748dcb1",
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "generator = torch.Generator(\"cuda\").manual_seed(target_seed)\n",
    "orig_image = pipe(f'a photo of a {concept}', guidance_scale=7.5,\n",
    "                  generator=generator,\n",
    "                  return_dict=False,\n",
    "                  num_images_per_prompt=1,\n",
    "                  num_inference_steps=num_inference_steps)[0][0]\n",
    "orig_image.resize((224, 224))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "bf62ce03",
   "metadata": {
    "id": "bf62ce03"
   },
   "source": [
    "### Visualize decomposition image"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e405e39e",
   "metadata": {
    "id": "e405e39e"
   },
   "outputs": [],
   "source": [
    "image = load_alphas(alphas, orig_embeddings, target_seed, f'{prompt} <>')\n",
    "image.resize((224, 224))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "de48de44",
   "metadata": {
    "id": "de48de44"
   },
   "source": [
    "## Single-image decomposition code"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "29066be2",
   "metadata": {
    "id": "29066be2"
   },
   "source": [
    "### Iteratively remove features from the decomposition"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "34945bd7",
   "metadata": {
    "id": "34945bd7",
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "with torch.no_grad():\n",
    "    # Start from the full decomposition and greedily try to remove features, accepting a removal\n",
    "    # only if the generated image stays CLIP-similar (> 0.93) to the full-decomposition image.\n",
    "    final_alphas = alphas.clone()\n",
    "    target_clip = clip_processor(images=image, return_tensors=\"pt\")['pixel_values'].cuda()\n",
    "    target_clip = clip_model.get_image_features(target_clip)\n",
    "    next_indices = []\n",
    "    removed = True\n",
    "    saving_images = False\n",
    "    indices = np.arange(num_alphas)[::-1]\n",
    "\n",
    "    while removed:\n",
    "        removed = False\n",
    "        for idx in indices:\n",
    "            # Tentatively zero out this feature, together with any near-duplicate tokens.\n",
    "            temp = final_alphas.clone()\n",
    "            temp[top_word_idx[idx]] = 0\n",
    "            for similar_idx in clip_words_similarity[idx].nonzero()[0]:\n",
    "                temp[top_word_idx[similar_idx]] = 0\n",
    "            image = load_alphas(temp, orig_embeddings, target_seed, f'{prompt} <>')\n",
    "\n",
    "            curr_clip = clip_processor(images=image, return_tensors=\"pt\")['pixel_values'].cuda()\n",
    "            curr_clip = clip_model.get_image_features(curr_clip)\n",
    "            similarity = torch.cosine_similarity(target_clip, curr_clip).item()\n",
    "            if similarity > 0.93:\n",
    "                print(f\"removing token at idx: {idx}\")\n",
    "                final_alphas = temp.clone()\n",
    "                removed = True\n",
    "            else:\n",
    "                print(f\"similarity: {similarity:.3f}, keeping token at idx: {idx}\")\n",
    "                next_indices.append(idx)\n",
    "        indices = next_indices\n",
    "        next_indices = []"
   ]
  },
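  {
   "cell_type": "markdown",
   "id": "c3d4e5f6",
   "metadata": {},
   "source": [
    "Optional summary (a minimal sketch using `final_alphas` and `alphas` from the pruning loop above): report how many features survived and their coefficients before visualizing them below."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c3d4e5f7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Count and list the surviving tokens with their coefficients.\n",
    "kept = torch.nonzero(final_alphas).flatten()\n",
    "print(f'{len(kept)} of {int((alphas != 0).sum())} features kept')\n",
    "for token_id in kept:\n",
    "    print(pipe.tokenizer.decode([int(token_id)]), f'{final_alphas[token_id].item():.4f}')"
   ]
  },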
  {
   "cell_type": "markdown",
   "id": "87e5fe1a",
   "metadata": {
    "id": "87e5fe1a"
   },
   "source": [
    "### Visualize image after removing features"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f74f42e6",
   "metadata": {
    "id": "f74f42e6"
   },
   "outputs": [],
   "source": [
    "image_decomp = load_alphas(final_alphas, orig_embeddings, target_seed, f'{prompt} <>')\n",
    "image_decomp.resize((224, 224))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6c922333",
   "metadata": {
    "id": "6c922333"
   },
   "source": [
    "### Visualize the remaining image features"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "02ee962b",
   "metadata": {
    "id": "02ee962b"
   },
   "outputs": [],
   "source": [
    "remaining_features = torch.nonzero(final_alphas).flatten()\n",
    "for feature in remaining_features:\n",
    "    print(\"feature: \", pipe.tokenizer.decode(feature))\n",
    "    generator = torch.Generator(\"cuda\").manual_seed(target_seed)\n",
    "    feature_visualization = pipe(f'a photo of a {pipe.tokenizer.decode(feature)}', guidance_scale=7.5,\n",
    "                                 generator=generator,\n",
    "                                 return_dict=False,\n",
    "                                 num_images_per_prompt=1,\n",
    "                                 num_inference_steps=num_inference_steps)[0][0]\n",
    "    display(feature_visualization.resize((224, 224)))"
   ]
  }
 ],
 "metadata": {
  "colab": {
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}