google-research

Форк
0
42 строки · 1.7 Кб
1
--- external/gsn/models/model_utils.py	2023-03-09 18:12:06.857800424 +0000
2
+++ external_reference/gsn/models/model_utils.py	2023-03-09 18:07:58.345065235 +0000
3
@@ -1,6 +1,7 @@
4
 import torch
5
 from torch import nn
6
 import numpy as np
7
+from torch_utils import persistence
8
 
9
 
10
 def flatten_trajectories(data):
11
@@ -68,7 +69,7 @@
12
         if nerf_out_res:
13
             self.nerf_out_res = nerf_out_res
14
 
15
-
16
+@persistence.persistent_class
17
 class TrajectorySampler(nn.Module):
18
     """Trajectory sampler.
19
 
20
@@ -99,7 +100,8 @@
21
     def __init__(self, real_Rts, mode='sample', num_bins=10, alpha_activation='relu', jitter_range=0):
22
         super().__init__()
23
 
24
-        self.real_Rts = nn.Parameter(real_Rts, requires_grad=False)  # shape [n_trajectories, seq_len, 4, 4]
25
+        # CHANGED: make a buffer s.t. does not update in GAN updates
26
+        self.register_buffer('real_Rts', real_Rts)
27
         self.mode = mode
28
         self.num_bins = num_bins
29
         self.alpha_activation = alpha_activation
30
@@ -107,9 +109,10 @@
31
 
32
         # convert Rt matrices to camera pose matrices, then extract translation component
33
         # make sure Rts are float, since inverse doesn't work with FP16
34
-        self.real_trajectories = real_Rts.float().inverse()[:, :, :3, 3].contiguous()
35
+        real_trajectories = real_Rts.float().inverse()[:, :, :3, 3].contiguous()
36
         # shape [n_trajectories, seq_len, 3]
37
-        self.real_trajectories = nn.Parameter(self.real_trajectories, requires_grad=False)
38
+        # CHANGED: make a buffer s.t. does not update in GAN updates
39
+        self.register_buffer('real_trajectories', real_trajectories)
40
         self.seq_len = self.real_trajectories.shape[1]
41
 
42
         if mode == 'bin':
43

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.