Training a Symplectic FlashMD Model
This tutorial demonstrates how to train a symplectic FlashMD model. Symplectic integrators preserve the geometric (phase-space) structure of Hamiltonian dynamics. In practice, this typically keeps the energy error bounded instead of letting it accumulate, which makes them particularly suitable for long, stable molecular dynamics simulations.
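As a rough illustration of why symplecticity matters (a toy sketch that is not part of FlashMD), the snippet below integrates a one-dimensional harmonic oscillator and compares explicit Euler, which is not symplectic, with velocity Verlet, which is. The unit mass, unit spring constant, step size, and all function names are assumptions chosen for simplicity.

```python
# Toy comparison on H = p**2 / 2 + q**2 / 2 (unit mass and spring constant assumed).

def energy(q, p):
    """Total energy of the toy harmonic oscillator."""
    return 0.5 * p * p + 0.5 * q * q


def euler_step(q, p, h):
    """Explicit Euler step: simple, but not symplectic."""
    return q + h * p, p - h * q


def verlet_step(q, p, h):
    """Velocity Verlet step: symplectic and time-reversible."""
    p_half = p - 0.5 * h * q          # half kick with the force -q
    q_new = q + h * p_half            # full drift
    p_new = p_half - 0.5 * h * q_new  # second half kick at the new position
    return q_new, p_new


def max_energy_error(step, n_steps=10000, h=0.05):
    """Largest absolute energy deviation seen along the trajectory."""
    q, p = 1.0, 0.0
    e0 = energy(q, p)
    worst = 0.0
    for _ in range(n_steps):
        q, p = step(q, p, h)
        worst = max(worst, abs(energy(q, p) - e0))
    return worst


euler_error = max_energy_error(euler_step)
verlet_error = max_energy_error(verlet_step)
print(f"explicit Euler:  {euler_error:.3e}")
print(f"velocity Verlet: {verlet_error:.3e}")
```

For this system, explicit Euler multiplies the energy by a factor of 1 + h**2 at every step, so its error grows without bound, while the velocity Verlet error stays small and oscillates.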
import copy
import subprocess
import ase
import ase.build
import ase.io
import ase.units
from ase.calculators.emt import EMT
from ase.md import VelocityVerlet
from ase.md.langevin import Langevin
from ase.md.velocitydistribution import MaxwellBoltzmannDistribution
Dataset Creation
We will create a dataset of molecular dynamics trajectories using ASE and its built-in EMT potential. The dataset will consist of atomic positions and momenta sampled from an NVE simulation. In practice, you might want to use a more accurate reference such as ab initio MD or a machine-learned interatomic potential (MLIP).
# We start by creating a simple system (a small box of aluminum).
atoms = ase.build.bulk("Al", "fcc", cubic=True) * (2, 2, 2)
# We first equilibrate the system at 300K using a Langevin thermostat.
MaxwellBoltzmannDistribution(atoms, temperature_K=300)
atoms.calc = EMT()
dyn = Langevin(
    atoms, 2 * ase.units.fs, temperature_K=300, friction=1 / (100 * ase.units.fs)
)
dyn.run(1000) # 2 ps equilibration (around 10 ps is better in practice)
# Then, we run a production simulation in the NVE ensemble.
trajectory = []
def store_trajectory():
    trajectory.append(copy.deepcopy(atoms))
dyn = VelocityVerlet(atoms, 1 * ase.units.fs)
dyn.attach(store_trajectory, interval=1)
dyn.run(2000) # 2 ps NVE run
Data Preparation
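The data preparation process is similar to the one in the 04-flashmd.py example, with one key difference. Instead of storing a phase-space point and its future state after one time step, each sample stores the input to the symplectic fixed-point solver. This input is the midpoint, that is, the average of the current and future positions and momenta. The targets are the changes in positions and momenta over one model time step, which here spans time_lag steps of the reference MD. Each pair is also added in time-reversed form, with the momenta flipped and the two frames swapped.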
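The midpoint-to-difference mapping has the same shape as the classical implicit midpoint rule, where the update is evaluated at the midpoint and the new state is found by fixed-point iteration. The sketch below shows that classical rule for a one-dimensional harmonic oscillator. It is an analogy under stated assumptions (unit mass, unit spring constant, hypothetical function names), not FlashMD's actual solver, where a trained model would play the role of delta_from_midpoint.

```python
# Classical implicit midpoint rule on H = p**2 / 2 + q**2 / 2 (toy assumption).

def energy(q, p):
    """Total energy of the toy harmonic oscillator."""
    return 0.5 * p * p + 0.5 * q * q


def delta_from_midpoint(q_mid, p_mid, h):
    """Map a midpoint to the change in (q, p) over one step of size h."""
    return h * p_mid, -h * q_mid


def implicit_midpoint_step(q, p, h, n_iter=50, tol=1e-14):
    """Solve delta = F(z + delta / 2) by plain fixed-point iteration."""
    dq, dp = 0.0, 0.0
    for _ in range(n_iter):
        new_dq, new_dp = delta_from_midpoint(q + 0.5 * dq, p + 0.5 * dp, h)
        change = abs(new_dq - dq) + abs(new_dp - dp)
        dq, dp = new_dq, new_dp
        if change < tol:
            break
    return q + dq, p + dp


def energy_drift(n_steps=10000, h=0.1):
    """Absolute energy change after n_steps implicit midpoint steps."""
    q, p = 1.0, 0.0
    e0 = energy(q, p)
    for _ in range(n_steps):
        q, p = implicit_midpoint_step(q, p, h)
    return abs(energy(q, p) - e0)


print(f"energy drift: {energy_drift():.3e}")
```

The training pairs built in this section have the same structure: the stored positions and momenta are the midpoint, and delta_positions and delta_momenta are the values the map should return.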
time_lag = 32
spacing = 200
def get_structure_for_dataset(frame_now, frame_ahead):
    s = copy.deepcopy(frame_now)
    s.arrays["delta_positions"] = (
        frame_ahead.get_positions() - frame_now.get_positions()
    )
    s.arrays["delta_momenta"] = frame_ahead.get_momenta() - frame_now.get_momenta()
    s.set_positions(0.5 * (frame_now.get_positions() + frame_ahead.get_positions()))
    s.set_momenta(0.5 * (frame_now.get_momenta() + frame_ahead.get_momenta()))
    return s
structures_for_dataset = []
for i in range(0, len(trajectory) - time_lag, spacing):
    # Forward pair: the frame at step i and the frame time_lag steps later.
    frame_now = trajectory[i]
    frame_ahead = trajectory[i + time_lag]
    s = get_structure_for_dataset(frame_now, frame_ahead)
    structures_for_dataset.append(s)
    # Time-reversed pair: flip the momenta and swap the roles of the two frames.
    frame_now_trev = copy.deepcopy(frame_now)
    frame_ahead_trev = copy.deepcopy(frame_ahead)
    frame_now_trev.set_momenta(-frame_now_trev.get_momenta())
    frame_ahead_trev.set_momenta(-frame_ahead_trev.get_momenta())
    s = get_structure_for_dataset(frame_ahead_trev, frame_now_trev)
    structures_for_dataset.append(s)
ase.io.write("midpoint-to-delta.xyz", structures_for_dataset)
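The time-reversed samples added in the loop above rely on the time-reversal symmetry of Hamiltonian dynamics. The following sketch checks this on a toy one-dimensional harmonic oscillator integrated with velocity Verlet. The system, step size, initial state, and helper names are assumptions made for illustration, and midpoint_and_delta is only a scalar analogue of get_structure_for_dataset.

```python
# Toy check of the time-reversal augmentation on H = p**2 / 2 + q**2 / 2.

def verlet_step(q, p, h):
    """One velocity Verlet step (time-reversible)."""
    p_half = p - 0.5 * h * q
    q_new = q + h * p_half
    return q_new, p_half - 0.5 * h * q_new


def propagate(q, p, h, n_steps):
    """Apply n_steps velocity Verlet steps."""
    for _ in range(n_steps):
        q, p = verlet_step(q, p, h)
    return q, p


def midpoint_and_delta(now, ahead):
    """Scalar analogue of get_structure_for_dataset: (midpoint, delta)."""
    (q0, p0), (q1, p1) = now, ahead
    return (0.5 * (q0 + q1), 0.5 * (p0 + p1)), (q1 - q0, p1 - p0)


now = (1.0, 0.3)
ahead = propagate(*now, h=0.01, n_steps=32)

# Time-reversed pair: start from the later frame with flipped momentum.
now_trev = (ahead[0], -ahead[1])
ahead_trev = (now[0], -now[1])

# The reversed pair is itself a valid trajectory segment.
check = propagate(*now_trev, h=0.01, n_steps=32)

mid, delta = midpoint_and_delta(now, ahead)
mid_trev, delta_trev = midpoint_and_delta(now_trev, ahead_trev)
print("forward :", mid, delta)
print("reversed:", mid_trev, delta_trev)
```

Compared with the forward sample, the reversed sample has the same midpoint position, the opposite midpoint momentum, the opposite position difference, and the same momentum difference.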
Model Training
We can now train a symplectic FlashMD model using the prepared dataset.
For example, you can use the following options file, saved here as options-flashmd-symplectic.yaml:
seed: 42
base_precision: 32
architecture:
  name: experimental.flashmd_symplectic
  training:
    timestep: 32  # in this case 32 (time lag) * 1 fs (timestep of reference MD)
    batch_size: 8  # to be increased in a production scenario
    num_epochs: 3  # to be increased (at least 1000-10000) in a production scenario
    log_interval: 1
    learning_rate: 3e-4
    # Note that the scaling weights should be set to the same values to prevent
    # inconsistencies in the loss function. Look at the field documentation
    # for further details.
    fixed_scaling_weights:
      positions: 1.0
      momenta: 1.0
    loss: mse
training_set:
  systems:
    read_from: midpoint-to-delta.xyz
    length_unit: A
  targets:
    positions:
      key: delta_positions
      quantity: length
      unit: A
      type:
        cartesian:
          rank: 1
      per_atom: true
    momenta:
      key: delta_momenta
      quantity: momentum
      unit: (eV*u)^(1/2)
      type:
        cartesian:
          rank: 1
      per_atom: true
validation_set: 0.1
test_set: 0.0
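The timestep entry has to agree with how the dataset was built. As a small sanity check, the sketch below recomputes the model time step from the time lag and the reference MD step, and counts the structures produced by the loop in the data preparation section. The helper names are hypothetical, and the frame count of about 2000 is an assumption based on the NVE run above (ASE may also record the initial frame, which gives 2001).

```python
# Sanity check for the options file, using the values from this tutorial.

def model_timestep_fs(time_lag, md_timestep_fs):
    """Model time step: number of MD steps skipped times the MD step."""
    return time_lag * md_timestep_fs


def count_structures(n_frames, time_lag, spacing, time_reversal=True):
    """Structures written by the data preparation loop."""
    n_pairs = len(range(0, n_frames - time_lag, spacing))
    # Each forward pair is accompanied by one time-reversed pair.
    return 2 * n_pairs if time_reversal else n_pairs


timestep = model_timestep_fs(time_lag=32, md_timestep_fs=1.0)
n_structures = count_structures(n_frames=2000, time_lag=32, spacing=200)
print(f"timestep: {timestep} fs, structures: {n_structures}")
```

With the values used here, this gives a 32 fs model time step and 20 structures, a very small dataset in keeping with the short demonstration settings.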
subprocess.run(["mtt", "train", "options-flashmd-symplectic.yaml"], check=True)
CompletedProcess(args=['mtt', 'train', 'options-flashmd-symplectic.yaml'], returncode=0)