lama

Форк
0
/
to_jit.py 
76 строк · 2.1 Кб
1
import os
2
from pathlib import Path
3

4
import hydra
5
import torch
6
import yaml
7
from omegaconf import OmegaConf
8
from torch import nn
9

10
from saicinpainting.training.trainers import load_checkpoint
11
from saicinpainting.utils import register_debug_signal_handlers
12

13

14
class JITWrapper(nn.Module):
15
    def __init__(self, model):
16
        super().__init__()
17
        self.model = model
18

19
    def forward(self, image, mask):
20
        batch = {
21
            "image": image,
22
            "mask": mask
23
        }
24
        out = self.model(batch)
25
        return out["inpainted"]
26

27

28
@hydra.main(config_path="../configs/prediction", config_name="default.yaml")
29
def main(predict_config: OmegaConf):
30
    if sys.platform != 'win32':
31
        register_debug_signal_handlers()  # kill -10 <pid> will result in traceback dumped into log
32

33
    train_config_path = os.path.join(predict_config.model.path, "config.yaml")
34
    with open(train_config_path, "r") as f:
35
        train_config = OmegaConf.create(yaml.safe_load(f))
36

37
    train_config.training_model.predict_only = True
38
    train_config.visualizer.kind = "noop"
39

40
    checkpoint_path = os.path.join(
41
        predict_config.model.path, "models", predict_config.model.checkpoint
42
    )
43
    model = load_checkpoint(
44
        train_config, checkpoint_path, strict=False, map_location="cpu"
45
    )
46
    model.eval()
47
    jit_model_wrapper = JITWrapper(model)
48

49
    image = torch.rand(1, 3, 120, 120)
50
    mask = torch.rand(1, 1, 120, 120)
51
    output = jit_model_wrapper(image, mask)
52

53
    if torch.cuda.is_available():
54
        device = torch.device("cuda")
55
    else:
56
        device = torch.device("cpu")
57

58
    image = image.to(device)
59
    mask = mask.to(device)
60
    traced_model = torch.jit.trace(jit_model_wrapper, (image, mask), strict=False).to(device)
61

62
    save_path = Path(predict_config.save_path)
63
    save_path.parent.mkdir(parents=True, exist_ok=True)
64

65
    print(f"Saving big-lama.pt model to {save_path}")
66
    traced_model.save(save_path)
67

68
    print(f"Checking jit model output...")
69
    jit_model = torch.jit.load(str(save_path))
70
    jit_output = jit_model(image, mask)
71
    diff = (output - jit_output).abs().sum()
72
    print(f"diff: {diff}")
73

74

75
if __name__ == "__main__":
76
    main()
77

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.