google-research
426 строк · 11.5 Кб
1# coding=utf-8
2# Copyright 2024 The Google Research Authors.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8# http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15
16"""Simulation of coupled harmonic motion."""
17
18from typing import Mapping, Optional, Sequence, Tuple19
20import chex21import jax22import jax.numpy as jnp23from matplotlib import animation24import matplotlib.pyplot as plt25import numpy as np26
27
def compute_normal_modes(simulation_parameters):
    """Returns the angular frequencies and eigenvectors for the normal modes.

    The system has `n` masses. Each mass is tied to a wall by a spring of
    constant `k_wall` and to every other mass by a spring of constant
    `k_pair`. The equations of motion are `q'' = M^{-1} K q` with
    `K = -(k_wall + n * k_pair) * I + k_pair * J`, where `J` is the all-ones
    matrix.

    Args:
      simulation_parameters: Mapping with keys "m" (array of masses, one per
        trajectory), "k_wall" and "k_pair" (spring constants).

    Returns:
      A tuple `(w, eigvecs)`. `w` holds the real angular frequency of each
      normal mode. The columns of `eigvecs` are the corresponding real parts
      of the eigenvectors of `M^{-1} K`.
    """
    m, k_wall, k_pair = (simulation_parameters["m"],
                         simulation_parameters["k_wall"],
                         simulation_parameters["k_pair"])
    num_trajectories = m.shape[0]

    # Construct coupling matrix.
    # Mass i feels -k_wall * q_i from its wall spring and
    # -k_pair * sum_{j != i} (q_i - q_j) from the pair springs. So the diagonal
    # is -(k_wall + (n - 1) * k_pair) and every off-diagonal entry is k_pair.
    # Writing it as -(k_wall + n * k_pair) * I + k_pair * J yields exactly that.
    coupling_matrix = (
        -(k_wall + num_trajectories * k_pair) * jnp.eye(num_trajectories) +
        k_pair * jnp.ones((num_trajectories, num_trajectories)))
    coupling_matrix = jnp.diag(1 / m) @ coupling_matrix

    # M^{-1} K is not symmetric for non-uniform masses, so the general
    # eigensolver is used. Its outputs are complex-typed; only the real parts
    # are kept.
    eigvals, eigvecs = jnp.linalg.eig(coupling_matrix)
    w = jnp.real(jnp.sqrt(-eigvals))
    eigvecs = jnp.real(eigvecs)
    return w, eigvecs
49
def generate_canonical_coordinates(t, simulation_parameters):
    """Returns q (position) and p (momentum) coordinates at instant t.

    The motion is evaluated independently for each normal mode with unit
    mass. It is then mapped back to the original coordinates through the
    normal-mode eigenvectors.

    Args:
      t: Time at which the coordinates are evaluated.
      simulation_parameters: Mapping with the keys read by
        `compute_normal_modes` ("m", "k_wall", "k_pair"), plus "A" and "phi",
        the amplitude and phase given to each normal mode.

    Returns:
      A tuple `(positions, momentums)` in the original coordinates.
    """
    w, eigvecs = compute_normal_modes(simulation_parameters)
    m = simulation_parameters["m"]
    normal_mode_simulation_parameters = {
        "A": simulation_parameters["A"],
        "phi": simulation_parameters["phi"],
        # We will scale momentums by mass later.
        "m": jnp.ones_like(m),
        "w": w,
    }
    normal_mode_trajectories = generate_canonical_coordinates_for_normal_mode(
        t, normal_mode_simulation_parameters)
    # `jax.tree_map` is deprecated and removed in newer JAX releases;
    # `jax.tree_util.tree_map` is the long-standing, portable equivalent.
    positions, momentums = jax.tree_util.tree_map(
        lambda arr: eigvecs @ arr, normal_mode_trajectories)
    # The mode momentums were computed with unit mass; scale by the actual
    # per-trajectory masses here.
    momentums = momentums * m
    return positions, momentums
72
def generate_canonical_coordinates_for_normal_mode(
    t,
    mode_simulation_parameters,
):
    """Returns q (position) and p (momentum) coordinates at instant t.

    Each mode is a simple harmonic oscillator:
      q(t) = A * cos(w * t + phi)
      p(t) = -m * w * A * sin(w * t + phi)

    Args:
      t: Time at which the coordinates are evaluated.
      mode_simulation_parameters: Mapping with keys "A" (amplitude), "phi"
        (phase), "m" (mass) and "w" (angular frequency).

    Returns:
      A tuple `(position, momentum)`.
    """
    amplitude = mode_simulation_parameters["A"]
    phase = mode_simulation_parameters["phi"]
    mass = mode_simulation_parameters["m"]
    omega = mode_simulation_parameters["w"]

    angle = omega * t + phase
    return (
        amplitude * jnp.cos(angle),
        -mass * omega * amplitude * jnp.sin(angle),
    )
86
def _squared_l2_distance(u, v):
    """Returns the squared Euclidean distance between `u` and `v`."""
    delta = u - v
    return jnp.sum(jnp.square(delta))
90
def compute_hamiltonian(
    position,
    momentum,
    simulation_parameters,
):
    """Computes the Hamiltonian at the given coordinates.

    H = sum_i p_i^2 / (2 m_i) + sum_i k_wall * q_i^2 / 2
        + k_pair / 2 * sum_{i < j} |q_i - q_j|^2

    Args:
      position: Positions q, one entry per trajectory along the leading axis.
      momentum: Momentums p, same shape as `position`.
      simulation_parameters: Mapping with keys "m", "k_wall" and "k_pair".
        "k_pair" may be a scalar or an array. Only its first element is used,
        so all pairs share a single spring constant.

    Returns:
      The scalar value of the Hamiltonian.
    """
    m = simulation_parameters["m"]
    k_wall = simulation_parameters["k_wall"]
    # Accept scalar as well as array-valued `k_pair`. The original `[0]`
    # indexing failed on scalars.
    k_pair = jnp.ravel(jnp.asarray(simulation_parameters["k_pair"]))[0]
    q, p = position, momentum

    # Sum |q_i - q_j|^2 over all ordered pairs (i, j) via broadcasting.
    # Halving counts each unordered pair exactly once.
    pairwise_differences = q[:, jnp.newaxis] - q[jnp.newaxis, :]
    squared_distances = jnp.sum(jnp.square(pairwise_differences)) / 2

    hamiltonian = ((p**2) / (2 * m)).sum()
    hamiltonian += (k_wall * (q**2)).sum() / 2
    hamiltonian += (k_pair * squared_distances) / 2
    return hamiltonian
110
def plot_coordinates(positions, momentums,
                     simulation_parameters,
                     title):
    """Plots coordinates in the canonical basis.

    Builds an animation with one frame per time step. In each frame, every
    trajectory is drawn as a dot at its position, with an arrow whose length
    is proportional to its momentum. The Hamiltonian of that step is printed
    above.

    Args:
      positions: Positions with time along the leading axis, shaped either
        `(num_steps,)` or `(num_steps, num_trajectories)`.
      momentums: Momentums with the same layout as `positions`.
      simulation_parameters: Parameters forwarded to `compute_hamiltonian`.
      title: Text shown above the "VISUALIZED" heading.

    Returns:
      A `matplotlib.animation.FuncAnimation`.
    """
    assert len(positions) == len(momentums)

    qs, ps = np.asarray(positions), np.asarray(momentums)
    if qs.ndim == 1:
        qs, ps = qs[..., np.newaxis], ps[..., np.newaxis]

    assert qs.ndim == 2, f"Got positions of shape {qs.shape}."
    assert ps.ndim == 2, f"Got momentums of shape {ps.shape}."

    # Create new Figure with black background
    fig = plt.figure(figsize=(8, 6), facecolor="black")

    # Add a subplot with no frame, explicitly on `fig`.
    ax = fig.add_subplot(frameon=False)

    # Compute Hamiltonians.
    num_steps = qs.shape[0]
    q_max = np.max(np.abs(qs))
    p_max = np.max(np.abs(ps))
    # The longest momentum arrow spans a fifth of the largest position.
    # Guard against dividing by zero when every momentum is zero.
    p_scale = (q_max / p_max) / 5 if p_max > 0 else 0.0
    hs = jax.vmap(  # pytype: disable=wrong-arg-types  # numpy-scalars
        compute_hamiltonian, in_axes=(0, 0, None))(qs, ps, simulation_parameters)
    # Flatten rather than squeeze, so a single time step still yields an
    # array that can be indexed by frame number.
    hs_formatted = np.round(np.asarray(hs).reshape(-1), 5)

    def update(t):
        # Update data
        ax.clear()

        # 2 part titles to get different font weights
        ax.text(
            0.5,
            1.0,
            title + " ",
            transform=ax.transAxes,
            ha="center",
            va="bottom",
            color="w",
            family="sans-serif",
            fontweight="light",
            fontsize=16)
        ax.text(
            0.5,
            0.93,
            "VISUALIZED",
            transform=ax.transAxes,
            ha="center",
            va="bottom",
            color="w",
            family="sans-serif",
            fontweight="bold",
            fontsize=16)

        for qs_series, ps_series in zip(qs.T, ps.T):
            ax.scatter(qs_series[t], 10, marker="o", s=40, color="white")
            ax.annotate(
                r"$q$",
                xy=(qs_series[t], 8),
                ha="center",
                va="center",
                size=12,
                color="white")
            ax.annotate(
                r"$p$",
                xy=(qs_series[t], 10 - 0.15),
                xytext=(qs_series[t] + ps_series[t] * p_scale, 10 - 0.15),
                arrowprops=dict(arrowstyle="<-", color="white"),
                ha="center",
                va="center",
                size=12,
                color="white")

        ax.plot([0, 0], [5, 15], linestyle="dashed", color="white")

        ax.annotate(
            r"$H$ = %0.5f" % hs_formatted[t],
            xy=(0, 40),
            ha="center",
            va="center",
            size=14,
            color="white")

        ax.set_xlim(-(q_max * 1.1), (q_max * 1.1))
        ax.set_ylim(-1, 50)

        # No ticks
        ax.set_xticks([])
        ax.set_yticks([])

    # Construct the animation with the update function as the animation director.
    anim = animation.FuncAnimation(
        fig, update, frames=num_steps, interval=100, blit=False)
    # Close this specific figure rather than whichever one is "current";
    # presumably this keeps it from also being shown as a static image.
    plt.close(fig)
    return anim
210
def plot_coordinates_in_phase_space(
    positions,
    momentums,
    simulation_parameters,
    title,
):
    """Plots a phase space diagram of the given coordinates.

    Builds an animation with one frame per time step. Every frame shows:
      * the full (q, p) path of each trajectory as small dots;
      * a larger marker at the current step;
      * the Hamiltonian of that step.

    Args:
      positions: Positions with time along the leading axis, shaped either
        `(num_steps,)` or `(num_steps, num_trajectories)`.
      momentums: Momentums with the same layout as `positions`.
      simulation_parameters: Parameters forwarded to `compute_hamiltonian`.
      title: Text shown above the "PHASE SPACE VISUALIZED" heading.

    Returns:
      A `matplotlib.animation.FuncAnimation`.
    """
    assert len(positions) == len(momentums)

    qs, ps = np.asarray(positions), np.asarray(momentums)
    if qs.ndim == 1:
        qs, ps = qs[..., np.newaxis], ps[..., np.newaxis]

    assert qs.ndim == 2, f"Got positions of shape {qs.shape}."
    assert ps.ndim == 2, f"Got momentums of shape {ps.shape}."

    # Create new Figure with black background
    fig = plt.figure(figsize=(8, 6), facecolor="black")

    # Add a subplot explicitly on `fig`, shifted down to leave room for the
    # figure-level titles.
    ax = fig.add_subplot(facecolor="black")
    pos = ax.get_position()
    ax.set_position([pos.x0, pos.y0 - 0.15, pos.width, pos.height])

    # Compute Hamiltonians.
    num_steps = qs.shape[0]
    q_max = np.max(np.abs(qs))
    p_max = np.max(np.abs(ps))
    hs = jax.vmap(  # pytype: disable=wrong-arg-types  # numpy-scalars
        compute_hamiltonian, in_axes=(0, 0, None))(qs, ps, simulation_parameters)
    # Flatten rather than squeeze, so a single time step still yields an
    # array that can be indexed by frame number.
    hs_formatted = np.round(np.asarray(hs).reshape(-1), 5)

    def update(t):
        # Update data
        ax.clear()

        # 2 part titles to get different font weights
        ax.text(
            0.5,
            0.83,
            title + " ",
            transform=fig.transFigure,
            ha="center",
            va="bottom",
            color="w",
            family="sans-serif",
            fontweight="light",
            fontsize=16)
        ax.text(
            0.5,
            0.78,
            "PHASE SPACE VISUALIZED",
            transform=fig.transFigure,
            ha="center",
            va="bottom",
            color="w",
            family="sans-serif",
            fontweight="bold",
            fontsize=16)

        for qs_series, ps_series in zip(qs.T, ps.T):
            ax.plot(
                qs_series,
                ps_series,
                marker="o",
                markersize=2,
                linestyle="None",
                color="white")
            ax.scatter(qs_series[t], ps_series[t], marker="o", s=40, color="white")

        ax.text(
            0,
            p_max * 1.7,
            r"$p$",
            ha="center",
            va="center",
            size=14,
            color="white")
        ax.text(
            q_max * 1.7,
            0,
            r"$q$",
            ha="center",
            va="center",
            size=14,
            color="white")

        ax.plot([-q_max * 1.5, q_max * 1.5], [0, 0],
                linestyle="dashed",
                color="white")
        ax.plot([0, 0], [-p_max * 1.5, p_max * 1.5],
                linestyle="dashed",
                color="white")

        ax.annotate(
            r"$H$ = %0.5f" % hs_formatted[t],
            xy=(0, p_max * 2.4),
            ha="center",
            va="center",
            size=14,
            color="white")

        ax.set_xlim(-(q_max * 2), (q_max * 2))
        ax.set_ylim(-(p_max * 2.5), (p_max * 2.5))

        # No ticks
        ax.set_xticks([])
        ax.set_yticks([])

    # Construct the animation with the update function as the animation director.
    anim = animation.FuncAnimation(
        fig, update, frames=num_steps, interval=100, blit=False)
    # Close this specific figure rather than whichever one is "current".
    plt.close(fig)
    return anim
328
def static_plot_coordinates_in_phase_space(
    positions,
    momentums,
    title,
    fig=None,
    ax=None,
    max_position=None,
    max_momentum=None):
    """Plots a static phase space diagram of the given coordinates.

    Draws the (q, p) path of every trajectory as small dots, with a larger
    marker on its first point.

    Args:
      positions: Positions with time along the leading axis, shaped either
        `(num_steps,)` or `(num_steps, num_trajectories)`.
      momentums: Momentums with the same layout as `positions`.
      title: Text shown above the "PHASE SPACE VISUALIZED" heading.
      fig: Optional figure to draw on. When omitted, the figure owning `ax` is
        used if `ax` is given; otherwise a new figure is created.
      ax: Optional axes to draw on. When omitted, a subplot is added to `fig`.
      max_position: Optional half-width used to scale the q axis. Defaults to
        the largest |q|.
      max_momentum: Optional half-height used to scale the p axis. Defaults to
        the largest |p|.

    Returns:
      The figure that was drawn on.
    """
    assert len(positions) == len(momentums)

    qs, ps = np.asarray(positions), np.asarray(momentums)
    if qs.ndim == 1:
        qs, ps = qs[..., np.newaxis], ps[..., np.newaxis]

    assert qs.ndim == 2, f"Got positions of shape {qs.shape}."
    assert ps.ndim == 2, f"Got momentums of shape {ps.shape}."

    if fig is None and ax is None:
        # Create new Figure with black background
        fig = plt.figure(figsize=(8, 6), facecolor="black")
    else:
        if fig is None:
            # Put the titles on the figure that owns the given axes instead of
            # a new, unrelated figure.
            fig = ax.figure
        fig.set_facecolor("black")

    if ax is None:
        # Add the subplot to `fig` itself. `plt.subplot` would target the
        # current figure, which need not be `fig`.
        ax = fig.add_subplot(facecolor="black", frameon=False)
    else:
        ax.set_facecolor("black")
        ax.set_frame_on(False)

    # Two part titles to get different font weights
    fig.text(
        x=0.5,
        y=0.83,
        s=title + " ",
        ha="center",
        va="bottom",
        color="w",
        family="sans-serif",
        fontweight="light",
        fontsize=16)
    fig.text(
        x=0.5,
        y=0.78,
        s="PHASE SPACE VISUALIZED",
        ha="center",
        va="bottom",
        color="w",
        family="sans-serif",
        fontweight="bold",
        fontsize=16)

    for qs_series, ps_series in zip(qs.T, ps.T):
        ax.plot(
            qs_series,
            ps_series,
            marker="o",
            markersize=2,
            linestyle="None",
            color="white")
        ax.scatter(qs_series[0], ps_series[0], marker="o", s=40, color="white")

    q_max = np.max(np.abs(qs)) if max_position is None else max_position
    p_max = np.max(np.abs(ps)) if max_momentum is None else max_momentum

    ax.text(
        0, p_max * 1.7, r"$p$", ha="center", va="center", size=14, color="white")
    ax.text(
        q_max * 1.7, 0, r"$q$", ha="center", va="center", size=14, color="white")

    ax.plot(
        [-q_max * 1.5, q_max * 1.5],  # pylint: disable=invalid-unary-operand-type
        [0, 0],
        linestyle="dashed",
        color="white")
    ax.plot(
        [0, 0],
        [-p_max * 1.5, p_max * 1.5],  # pylint: disable=invalid-unary-operand-type
        linestyle="dashed",
        color="white")

    ax.set_xlim(-(q_max * 2), (q_max * 2))
    ax.set_ylim(-(p_max * 2.5), (p_max * 2.5))

    # No ticks
    ax.set_xticks([])
    ax.set_yticks([])
    # Close the figure that was drawn on, not whichever one is "current".
    plt.close(fig)
    return fig  # pytype: disable=bad-return-type