google-research

Форк
0
426 строк · 11.5 Кб
1
# coding=utf-8
2
# Copyright 2024 The Google Research Authors.
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
#     http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15

16
"""Simulation of coupled harmonic motion."""
17

18
from typing import Mapping, Optional, Sequence, Tuple
19

20
import chex
21
import jax
22
import jax.numpy as jnp
23
from matplotlib import animation
24
import matplotlib.pyplot as plt
25
import numpy as np
26

27

28
def compute_normal_modes(
    simulation_parameters
):
  """Returns the angular frequencies and eigenvectors for the normal modes.

  Every particle is tied to the wall by a spring of stiffness `k_wall` and to
  every other particle by a spring of stiffness `k_pair`, i.e. the potential
  is

    V(q) = k_wall / 2 * sum_i q_i**2 + k_pair / 2 * sum_{i<j} (q_i - q_j)**2,

  which is the potential used by `compute_hamiltonian`. The linear equations
  of motion are then m * q'' = K @ q with

    K = -(k_wall + N * k_pair) * I + k_pair * ones((N, N)),

  where N is the number of particles.

  Args:
    simulation_parameters: Mapping with entries "m" (one mass per particle),
      "k_wall" and "k_pair" (spring constants; presumably scalars or arrays
      broadcastable against an (N, N) matrix -- TODO confirm).

  Returns:
    A tuple (w, eigvecs): the real parts of the normal-mode angular
    frequencies, and the real parts of the eigenvectors of diag(1 / m) @ K,
    stored as the columns of `eigvecs`.
  """
  m, k_wall, k_pair = (simulation_parameters["m"],
                       simulation_parameters["k_wall"],
                       simulation_parameters["k_pair"])
  num_trajectories = m.shape[0]

  # Construct the coupling matrix. Each particle is coupled to the N - 1
  # others, so the diagonal must be -(k_wall + (N - 1) * k_pair): the
  # identity term contributes -(k_wall + N * k_pair) and the all-ones term adds
  # back k_pair. (A hard-coded 2 * k_pair here is only correct for N == 2.)
  coupling_matrix = (
      -(k_wall + num_trajectories * k_pair) * jnp.eye(num_trajectories)
      + k_pair * jnp.ones((num_trajectories, num_trajectories)))
  coupling_matrix = jnp.diag(1 / m) @ coupling_matrix

  # diag(1 / m) @ K is not symmetric for unequal masses, so the general
  # eigensolver is used. Its eigenvalues are -w**2.
  eigvals, eigvecs = jnp.linalg.eig(coupling_matrix)
  w = jnp.real(jnp.sqrt(-eigvals))
  eigvecs = jnp.real(eigvecs)
  return w, eigvecs
48

49

50
def generate_canonical_coordinates(
    t, simulation_parameters
):
  """Returns q (position) and p (momentum) coordinates at instant t.

  The motion is built as a superposition of normal modes. Each mode
  oscillates with amplitude "A", phase "phi" and the frequency returned by
  `compute_normal_modes`. The mode trajectories are then mapped back to
  particle coordinates through the mode eigenvectors.

  Args:
    t: Time at which to evaluate the trajectory (presumably a scalar, with
      batches of times handled by the caller via vmap -- TODO confirm).
    simulation_parameters: Mapping with at least "m", "k_wall", "k_pair", "A"
      and "phi".

  Returns:
    A tuple (positions, momentums) in particle coordinates.
  """
  w, eigvecs = compute_normal_modes(simulation_parameters)
  m = simulation_parameters["m"]
  normal_mode_simulation_parameters = {
      "A": simulation_parameters["A"],
      "phi": simulation_parameters["phi"],
      # With unit masses the mode "momentum" is a velocity. It is turned into
      # a momentum with the real masses after the change of basis below.
      "m": jnp.ones_like(m),
      "w": w,
  }
  normal_mode_trajectories = generate_canonical_coordinates_for_normal_mode(
      t, normal_mode_simulation_parameters)
  # `jax.tree_map` is deprecated; `jax.tree_util.tree_map` is the stable API.
  positions, velocities = jax.tree_util.tree_map(
      lambda arr: eigvecs @ arr, normal_mode_trajectories)
  # p = m * v, applied per particle.
  momentums = velocities * m
  return positions, momentums
71

72

73
def generate_canonical_coordinates_for_normal_mode(
    t,
    mode_simulation_parameters,
):
  """Evaluates a simple harmonic oscillation at time t.

  Computes q(t) = A * cos(w * t + phi) and
  p(t) = -m * w * A * sin(w * t + phi), elementwise.

  Args:
    t: Evaluation time.
    mode_simulation_parameters: Mapping with keys "A" (amplitude), "phi"
      (phase), "m" (mass) and "w" (angular frequency).

  Returns:
    A (position, momentum) tuple.
  """
  params = mode_simulation_parameters
  amplitude = params["A"]
  angle = params["w"] * t + params["phi"]
  position = amplitude * jnp.cos(angle)
  momentum = -params["m"] * params["w"] * amplitude * jnp.sin(angle)
  return position, momentum
85

86

87
def _squared_l2_distance(u, v):
88
  return jnp.square(u - v).sum()
89

90

91
def compute_hamiltonian(
    position,
    momentum,
    simulation_parameters,
):
  """Computes the Hamiltonian (total energy) at the given coordinates.

  H = sum(p**2 / (2 * m))
      + sum(k_wall * q**2) / 2
      + k_pair / 2 * sum_{i<j} ||q_i - q_j||**2

  Only the first entry of "k_pair" is used.

  Args:
    position: Positions q; particles are indexed along axis 0.
    momentum: Momenta p, laid out like `position`.
    simulation_parameters: Mapping with entries "m", "k_wall" and "k_pair".

  Returns:
    The Hamiltonian as a scalar.
  """
  masses = simulation_parameters["m"]
  k_wall = simulation_parameters["k_wall"]
  k_pair = simulation_parameters["k_pair"][0]

  # Build the full matrix of squared distances between particles i and j:
  # the outer vmap walks the first argument, the inner one the second.
  distances_to = jax.vmap(_squared_l2_distance, in_axes=(None, 0))
  all_pairs = jax.vmap(distances_to, in_axes=(0, None))(position, position)
  # The matrix is symmetric with a zero diagonal, so halving its total counts
  # every unordered pair exactly once.
  pair_sum = jnp.sum(all_pairs) / 2

  kinetic = ((momentum**2) / (2 * masses)).sum()
  wall_potential = (k_wall * (position**2)).sum() / 2
  coupling_potential = (k_pair * pair_sum) / 2
  return (kinetic + wall_potential) + coupling_potential
109

110

111
def plot_coordinates(positions, momentums,
                     simulation_parameters,
                     title):
  """Animates the coordinates in the canonical basis.

  Every frame draws each particle as a dot at its position q on a horizontal
  line, an arrow proportional to its momentum p, and the Hamiltonian value.

  Args:
    positions: Positions per time step, of shape (num_steps,) for a single
      particle or (num_steps, num_particles).
    momentums: Momenta per time step, expected to match `positions`.
    simulation_parameters: Parameters forwarded to `compute_hamiltonian`.
    title: Text shown above the animation.

  Returns:
    A `matplotlib.animation.FuncAnimation` with one frame per time step.
  """
  assert len(positions) == len(momentums)

  qs, ps = np.asarray(positions), np.asarray(momentums)
  if qs.ndim == 1:
    # Single particle: add a trailing particle axis.
    qs, ps = qs[Ellipsis, np.newaxis], ps[Ellipsis, np.newaxis]

  assert qs.ndim == 2, f"Got positions of shape {qs.shape}."
  assert ps.ndim == 2, f"Got momentums of shape {ps.shape}."

  # Create new Figure with black background, and a frameless subplot on it.
  fig = plt.figure(figsize=(8, 6), facecolor="black")
  ax = fig.add_subplot(frameon=False)

  num_steps = qs.shape[0]
  q_max = np.max(np.abs(qs))
  p_max = np.max(np.abs(ps))
  # Scale momentum arrows so the largest is a fifth of the largest position.
  # All-zero momenta would otherwise divide by zero and give non-finite
  # arrow endpoints.
  p_scale = (q_max / p_max) / 5 if p_max > 0 else 0.0

  # Compute Hamiltonians. Flatten instead of squeezing so that a single-step
  # trajectory still gives an indexable 1-D array.
  hs = jax.vmap(  # pytype: disable=wrong-arg-types  # numpy-scalars
      compute_hamiltonian, in_axes=(0, 0, None))(qs, ps, simulation_parameters)
  hs_formatted = np.round(np.asarray(hs).reshape(-1), 5)

  def update(t):
    """Redraws frame `t` from scratch."""
    ax.clear()

    # 2 part titles to get different font weights
    ax.text(
        0.5,
        1.0,
        title + " ",
        transform=ax.transAxes,
        ha="center",
        va="bottom",
        color="w",
        family="sans-serif",
        fontweight="light",
        fontsize=16)
    ax.text(
        0.5,
        0.93,
        "VISUALIZED",
        transform=ax.transAxes,
        ha="center",
        va="bottom",
        color="w",
        family="sans-serif",
        fontweight="bold",
        fontsize=16)

    for qs_series, ps_series in zip(qs.T, ps.T):
      ax.scatter(qs_series[t], 10, marker="o", s=40, color="white")
      ax.annotate(
          r"$q$",
          xy=(qs_series[t], 8),
          ha="center",
          va="center",
          size=12,
          color="white")
      ax.annotate(
          r"$p$",
          xy=(qs_series[t], 10 - 0.15),
          xytext=(qs_series[t] + ps_series[t] * p_scale, 10 - 0.15),
          arrowprops=dict(arrowstyle="<-", color="white"),
          ha="center",
          va="center",
          size=12,
          color="white")

    # Vertical dashed line marking q = 0.
    ax.plot([0, 0], [5, 15], linestyle="dashed", color="white")

    ax.annotate(
        r"$H$ = %0.5f" % hs_formatted[t],
        xy=(0, 40),
        ha="center",
        va="center",
        size=14,
        color="white")

    ax.set_xlim(-(q_max * 1.1), (q_max * 1.1))
    ax.set_ylim(-1, 50)

    # No ticks
    ax.set_xticks([])
    ax.set_yticks([])

  # Construct the animation with the update function as the animation director.
  anim = animation.FuncAnimation(
      fig, update, frames=num_steps, interval=100, blit=False)
  # Close exactly the figure created here (not whatever figure is current).
  plt.close(fig)
  return anim
209

210

211
def plot_coordinates_in_phase_space(
    positions,
    momentums,
    simulation_parameters,
    title,
):
  """Animates a phase space diagram of the given coordinates.

  Every frame draws the full (q, p) trajectory of each particle as small dots,
  highlights the current time step with a larger marker, and shows the
  Hamiltonian value.

  Args:
    positions: Positions per time step, of shape (num_steps,) for a single
      particle or (num_steps, num_particles).
    momentums: Momenta per time step, expected to match `positions`.
    simulation_parameters: Parameters forwarded to `compute_hamiltonian`.
    title: Text shown above the animation.

  Returns:
    A `matplotlib.animation.FuncAnimation` with one frame per time step.
  """
  assert len(positions) == len(momentums)

  qs, ps = np.asarray(positions), np.asarray(momentums)
  if qs.ndim == 1:
    # Single particle: add a trailing particle axis.
    qs, ps = qs[Ellipsis, np.newaxis], ps[Ellipsis, np.newaxis]

  assert qs.ndim == 2, f"Got positions of shape {qs.shape}."
  assert ps.ndim == 2, f"Got momentums of shape {ps.shape}."

  # Create new Figure with black background.
  fig = plt.figure(figsize=(8, 6), facecolor="black")

  # Add a subplot to that figure, shifted down to leave room for the titles.
  ax = fig.add_subplot(facecolor="black")
  pos = ax.get_position()
  pos = [pos.x0, pos.y0 - 0.15, pos.width, pos.height]
  ax.set_position(pos)

  num_steps = qs.shape[0]
  q_max = np.max(np.abs(qs))
  p_max = np.max(np.abs(ps))

  # Compute Hamiltonians. Flatten instead of squeezing so that a single-step
  # trajectory still gives an indexable 1-D array.
  hs = jax.vmap(  # pytype: disable=wrong-arg-types  # numpy-scalars
      compute_hamiltonian, in_axes=(0, 0, None))(qs, ps, simulation_parameters)
  hs_formatted = np.round(np.asarray(hs).reshape(-1), 5)

  def update(t):
    """Redraws frame `t` from scratch."""
    ax.clear()

    # 2 part titles to get different font weights
    ax.text(
        0.5,
        0.83,
        title + " ",
        transform=fig.transFigure,
        ha="center",
        va="bottom",
        color="w",
        family="sans-serif",
        fontweight="light",
        fontsize=16)
    ax.text(
        0.5,
        0.78,
        "PHASE SPACE VISUALIZED",
        transform=fig.transFigure,
        ha="center",
        va="bottom",
        color="w",
        family="sans-serif",
        fontweight="bold",
        fontsize=16)

    for qs_series, ps_series in zip(qs.T, ps.T):
      # Whole trajectory as small dots, current step as a larger marker.
      ax.plot(
          qs_series,
          ps_series,
          marker="o",
          markersize=2,
          linestyle="None",
          color="white")
      ax.scatter(qs_series[t], ps_series[t], marker="o", s=40, color="white")

    ax.text(
        0,
        p_max * 1.7,
        r"$p$",
        ha="center",
        va="center",
        size=14,
        color="white")
    ax.text(
        q_max * 1.7,
        0,
        r"$q$",
        ha="center",
        va="center",
        size=14,
        color="white")

    # Dashed q and p axes through the origin.
    ax.plot([-q_max * 1.5, q_max * 1.5], [0, 0],
            linestyle="dashed",
            color="white")
    ax.plot([0, 0], [-p_max * 1.5, p_max * 1.5],
            linestyle="dashed",
            color="white")

    ax.annotate(
        r"$H$ = %0.5f" % hs_formatted[t],
        xy=(0, p_max * 2.4),
        ha="center",
        va="center",
        size=14,
        color="white")

    ax.set_xlim(-(q_max * 2), (q_max * 2))
    ax.set_ylim(-(p_max * 2.5), (p_max * 2.5))

    # No ticks
    ax.set_xticks([])
    ax.set_yticks([])

  # Construct the animation with the update function as the animation director.
  anim = animation.FuncAnimation(
      fig, update, frames=num_steps, interval=100, blit=False)
  # Close exactly the figure created here (not whatever figure is current).
  plt.close(fig)
  return anim
327

328

329
def static_plot_coordinates_in_phase_space(
    positions,
    momentums,
    title,
    fig=None,
    ax=None,
    max_position=None,
    max_momentum=None):
  """Plots a static phase space diagram of the given coordinates.

  Every time step is drawn as a small dot in the (q, p) plane, and the first
  time step of each particle is highlighted with a larger marker.

  Args:
    positions: Positions per time step, of shape (num_steps,) for a single
      particle or (num_steps, num_particles).
    momentums: Momenta per time step, expected to match `positions`.
    title: Text shown above the diagram.
    fig: Optional figure that receives the titles. A new 8x6 figure is
      created when omitted. Its facecolor is set to black either way.
    ax: Optional axes to draw into. When omitted, a subplot is added to `fig`.
    max_position: Optional |q| used to scale the axes. Defaults to the largest
      |q| in `positions`.
    max_momentum: Optional |p| used to scale the axes. Defaults to the largest
      |p| in `momentums`.

  Returns:
    The figure holding the titles.
  """
  assert len(positions) == len(momentums)

  qs, ps = np.asarray(positions), np.asarray(momentums)
  if qs.ndim == 1:
    # Single particle: add a trailing particle axis.
    qs, ps = qs[Ellipsis, np.newaxis], ps[Ellipsis, np.newaxis]

  assert qs.ndim == 2, f"Got positions of shape {qs.shape}."
  assert ps.ndim == 2, f"Got momentums of shape {ps.shape}."

  # NOTE(review): when only `ax` is given, the titles go on a freshly created
  # figure rather than on the figure owning `ax` -- confirm this is intended.
  if fig is None:
    # Create new Figure with black background
    fig = plt.figure(figsize=(8, 6), facecolor="black")
  else:
    fig.set_facecolor("black")

  if ax is None:
    # Add the subplot to `fig` explicitly: `plt.subplot` would target pyplot's
    # *current* figure, which need not be the `fig` passed in by the caller.
    ax = fig.add_subplot(facecolor="black", frameon=False)
  else:
    ax.set_facecolor("black")
    ax.set_frame_on(False)

  # Two part titles to get different font weights
  fig.text(
      x=0.5,
      y=0.83,
      s=title + " ",
      ha="center",
      va="bottom",
      color="w",
      family="sans-serif",
      fontweight="light",
      fontsize=16)
  fig.text(
      x=0.5,
      y=0.78,
      s="PHASE SPACE VISUALIZED",
      ha="center",
      va="bottom",
      color="w",
      family="sans-serif",
      fontweight="bold",
      fontsize=16)

  for qs_series, ps_series in zip(qs.T, ps.T):
    # Whole trajectory as small dots, initial step as a larger marker.
    ax.plot(
        qs_series,
        ps_series,
        marker="o",
        markersize=2,
        linestyle="None",
        color="white")
    ax.scatter(qs_series[0], ps_series[0], marker="o", s=40, color="white")

  q_max = np.max(np.abs(qs)) if max_position is None else max_position
  p_max = np.max(np.abs(ps)) if max_momentum is None else max_momentum

  ax.text(
      0, p_max * 1.7, r"$p$", ha="center", va="center", size=14, color="white")
  ax.text(
      q_max * 1.7, 0, r"$q$", ha="center", va="center", size=14, color="white")

  # Dashed q and p axes through the origin.
  ax.plot(
      [-q_max * 1.5, q_max * 1.5],  # pylint: disable=invalid-unary-operand-type
      [0, 0],
      linestyle="dashed",
      color="white")
  ax.plot(
      [0, 0],
      [-p_max * 1.5, p_max * 1.5],  # pylint: disable=invalid-unary-operand-type
      linestyle="dashed",
      color="white")

  ax.set_xlim(-(q_max * 2), (q_max * 2))
  ax.set_ylim(-(p_max * 2.5), (p_max * 2.5))

  # No ticks
  ax.set_xticks([])
  ax.set_yticks([])
  # Close the figure drawn on (rather than pyplot's current figure), so an
  # unrelated figure is never closed by accident.
  plt.close(fig)
  return fig  # pytype: disable=bad-return-type
427

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.