google-research
42 строки · 1.7 Кб
1--- external/gsn/models/model_utils.py 2023-03-09 18:12:06.857800424 +0000
2+++ external_reference/gsn/models/model_utils.py 2023-03-09 18:07:58.345065235 +0000
3@@ -1,6 +1,7 @@
4import torch
5from torch import nn
6import numpy as np
7+from torch_utils import persistence
8
9
10def flatten_trajectories(data):
11@@ -68,7 +69,7 @@
12if nerf_out_res:
13self.nerf_out_res = nerf_out_res
14
15-
16+@persistence.persistent_class
17class TrajectorySampler(nn.Module):
18"""Trajectory sampler.
19
20@@ -99,7 +100,8 @@
21def __init__(self, real_Rts, mode='sample', num_bins=10, alpha_activation='relu', jitter_range=0):
22super().__init__()
23
24- self.real_Rts = nn.Parameter(real_Rts, requires_grad=False) # shape [n_trajectories, seq_len, 4, 4]
25+ # CHANGED: make a buffer s.t. does not update in GAN updates
26+ self.register_buffer('real_Rts', real_Rts)
27self.mode = mode
28self.num_bins = num_bins
29self.alpha_activation = alpha_activation
30@@ -107,9 +109,10 @@
31
32# convert Rt matrices to camera pose matrices, then extract translation component
33# make sure Rts are float, since inverse doesn't work with FP16
34- self.real_trajectories = real_Rts.float().inverse()[:, :, :3, 3].contiguous()
35+ real_trajectories = real_Rts.float().inverse()[:, :, :3, 3].contiguous()
36# shape [n_trajectories, seq_len, 3]
37- self.real_trajectories = nn.Parameter(self.real_trajectories, requires_grad=False)
38+ # CHANGED: make a buffer s.t. does not update in GAN updates
39+ self.register_buffer('real_trajectories', real_trajectories)
40self.seq_len = self.real_trajectories.shape[1]
41
42if mode == 'bin':
43