# Source code for openferro.integrator.md

"""
Integrators for unconstrained molecular dynamics.

This file is part of OpenFerro.

"""

import jax
from jax import jit
import jax.numpy as jnp
from openferro.units import Constants
from openferro.integrator.base import Integrator


class GradientDescentIntegrator(Integrator):
    """
    Gradient descent integrator.

    Every step displaces the field values along the force, scaled by the
    inverse mass and the time step: ``x <- x + f / m * dt``.

    Parameters
    ----------
    dt : float
        Time step size
    """

    def __init__(self, dt):
        super().__init__(dt)
        # Wrap the position kernel with jit once per instance.
        self.step_x = jit(self._step_x)

    def _step_x(self, x, f, m, dt):
        """Return ``x`` displaced by ``f / m * dt``."""
        displacement = f / m * dt
        return x + displacement

    def step(self, field):
        """
        Update the field by one time step.

        Reads the current values, force and mass from ``field``, applies
        the jitted position update and stores the new values back.

        Parameters
        ----------
        field : Field
            The field to be updated

        Returns
        -------
        Field
            The updated field (the same object that was passed in)
        """
        values = field.get_values()
        force = field.get_force()
        updated = self.step_x(values, force, field.get_mass(), self.dt)
        field.set_values(updated)
        return field
class GradientDescentIntegrator_Strain(GradientDescentIntegrator):
    """
    Gradient descent integrator for global strain.

    The update is multiplied by a six-entry mask. With no freeze flag set
    the mask is all ones. As soon as any flag is set, entries 0-2 follow
    the individual flags and entries 3-5 are zero, so those three
    components are held fixed as well (presumably the shear components
    of the strain -- confirm against the strain field layout).

    Parameters
    ----------
    dt : float
        Time step size
    freeze_x : bool, optional
        Whether to freeze motion in x direction
    freeze_y : bool, optional
        Whether to freeze motion in y direction
    freeze_z : bool, optional
        Whether to freeze motion in z direction
    """

    def __init__(self, dt, freeze_x=False, freeze_y=False, freeze_z=False):
        super().__init__(dt)

        if freeze_x or freeze_y or freeze_z:
            # 1 for a free axis, 0 for a frozen one; last three entries are always 0 here.
            free_axes = [int(not flag) for flag in (freeze_x, freeze_y, freeze_z)]
            self.mask = jnp.array(free_axes + [0, 0, 0])
        else:
            self.mask = jnp.ones((6,))

        def _masked_step_x(x, f, m, dt):
            # Same update as the parent class, with the displacement masked.
            return x + f / m * dt * self.mask

        # Replace the jitted kernel installed by the parent constructor.
        self.step_x = jit(_masked_step_x)
class LeapFrogIntegrator(Integrator):
    """
    Leapfrog integrator.

    Each step first advances the velocity by ``f / m * dt`` and then
    advances the values using the freshly updated velocity.

    Parameters
    ----------
    dt : float
        Time step size
    """

    def __init__(self, dt):
        super().__init__(dt)
        # Wrap the combined position/velocity kernel with jit once per instance.
        self.step_xp = jit(self._step_xp)

    def _step_xp(self, x, v, f, m, dt):
        """Advance velocity, then position; return ``(x, v)``."""
        v += f / m * dt
        # The position update uses the already-updated velocity.
        x += v * dt
        return x, v

    def step(self, field):
        """
        Update the field by one time step.

        Parameters
        ----------
        field : Field
            The field to be updated

        Returns
        -------
        Field
            The updated field (the same object that was passed in)
        """
        values = field.get_values()
        velocity = field.get_velocity()
        values, velocity = self.step_xp(
            values,
            velocity,
            field.get_force(),
            field.get_mass(),
            self.dt,
        )
        field.set_values(values)
        field.set_velocity(velocity)
        return field
class LeapFrogIntegrator_Strain(LeapFrogIntegrator):
    """
    Leapfrog integrator for global strain.

    The velocity is multiplied by a six-entry mask after every force
    update. With no freeze flag set the mask is all ones. As soon as any
    flag is set, entries 0-2 follow the individual flags and entries 3-5
    are zero, so those three components are held fixed as well
    (presumably the shear components of the strain -- confirm against
    the strain field layout).

    Parameters
    ----------
    dt : float
        Time step size
    freeze_x : bool, optional
        Whether to freeze motion in x direction
    freeze_y : bool, optional
        Whether to freeze motion in y direction
    freeze_z : bool, optional
        Whether to freeze motion in z direction
    """

    def __init__(self, dt, freeze_x=False, freeze_y=False, freeze_z=False):
        super().__init__(dt)

        if freeze_x or freeze_y or freeze_z:
            # 1 for a free axis, 0 for a frozen one; last three entries are always 0 here.
            free_axes = [int(not flag) for flag in (freeze_x, freeze_y, freeze_z)]
            self.mask = jnp.array(free_axes + [0, 0, 0])
        else:
            self.mask = jnp.ones((6,))

        def _masked_step_xp(x, v, f, m, dt):
            v += f / m * dt
            # Zero the velocity of frozen components before moving.
            v *= self.mask
            x += v * dt
            return x, v

        # Replace the jitted kernel installed by the parent constructor.
        self.step_xp = jit(_masked_step_xp)
class LangevinIntegrator(Integrator):
    """
    Langevin integrator as in J. Phys. Chem. A 2019, 123, 28, 6056-6079.

    ABOBA scheme:
    exp(i L dt) = exp(i Lx dt/2)exp(i Lt dt)exp(i Lx dt/2)exp(i Lp dt)

    Lx/2: half-step position update (_step_x)
    Lt:   velocity update from damping and noise (_step_t)
    Lp:   full-step velocity update (_step_p)

    Within :meth:`step` the kernels are applied in the order
    ``_step_p``, ``_step_x``, ``_step_t``, ``_step_x``.

    Parameters
    ----------
    dt : float
        Time step size
    temp : float
        Temperature
    tau : float
        Relaxation time
    """

    def __init__(self, dt, temp, tau):
        super().__init__(dt)
        self.temp = temp
        self.tau = tau
        self.kbT = Constants.kb * temp
        self.gamma = 1.0 / tau
        # Coefficients of the velocity/noise mix used by _step_t.
        self.z1 = jnp.exp(-dt * self.gamma)
        self.z2 = jnp.sqrt(1 - jnp.exp(-2 * dt * self.gamma))
        # Wrap each kernel with jit once per instance.
        self.step_p = jit(self._step_p)
        self.step_x = jit(self._step_x)
        self.step_t = jit(self._step_t)

    def _step_p(self, v, f, m, dt):
        """Full-step velocity update from the force."""
        v += f / m * dt
        return v

    def _step_x(self, x, v, dt):
        """Half-step position update."""
        x += 0.5 * v * dt
        return x

    def _step_t(self, v, noise, z1, z2):
        """Mix the velocity with the noise: ``z1 * v + z2 * noise``."""
        return z1 * v + z2 * noise

    def get_noise(self, key, field):
        """
        Generate random noise for the Langevin dynamics.

        Draws standard-normal samples with the shape of the field's
        velocity and, if their sharding differs from ``field._sharding``,
        moves them onto the field's sharding.

        Parameters
        ----------
        key : jax.random.PRNGKey
            Random number generator key
        field : Field
            The field to generate noise for

        Returns
        -------
        jax.Array
            Random noise array
        """
        noise = jax.random.normal(key, field.get_velocity().shape)
        if field._sharding != noise.sharding:
            return jax.device_put(noise, field._sharding)
        return noise

    def step(self, key, field):
        """
        Update the field by one time step.

        Parameters
        ----------
        key : jax.random.PRNGKey
            Random number generator key
        field : Field
            The field to be updated

        Returns
        -------
        Field
            The updated field (the same object that was passed in)
        """
        dt = self.dt
        mass = field.get_mass()
        force = field.get_force()
        velocity = field.get_velocity()
        values = field.get_values()

        # Full velocity kick followed by the first half position move.
        velocity = self.step_p(velocity, force, mass, dt)
        values = self.step_x(values, velocity, dt)

        # Scale the unit-variance noise by sqrt(kbT / mass) before mixing.
        noise = self.get_noise(key, field)
        noise *= (self.kbT / mass) ** 0.5
        velocity = self.step_t(velocity, noise, self.z1, self.z2)

        # Second half position move with the thermostatted velocity.
        values = self.step_x(values, velocity, dt)

        field.set_values(values)
        field.set_velocity(velocity)
        return field
class LangevinIntegrator_Strain(LangevinIntegrator):
    """
    Langevin integrator for global strain.

    The velocity is multiplied by a six-entry mask after the force
    update and after the damping/noise update. With no freeze flag set
    the mask is all ones. As soon as any flag is set, entries 0-2 follow
    the individual flags and entries 3-5 are zero, so those three
    components are held fixed as well (presumably the shear components
    of the strain -- confirm against the strain field layout).

    Parameters
    ----------
    dt : float
        Time step size
    temp : float
        Temperature
    tau : float
        Relaxation time
    freeze_x : bool, optional
        Whether to freeze motion in x direction
    freeze_y : bool, optional
        Whether to freeze motion in y direction
    freeze_z : bool, optional
        Whether to freeze motion in z direction
    """

    def __init__(self, dt, temp, tau, freeze_x=False, freeze_y=False, freeze_z=False):
        super().__init__(dt, temp, tau)

        if freeze_x or freeze_y or freeze_z:
            # 1 for a free axis, 0 for a frozen one; last three entries are always 0 here.
            free_axes = [int(not flag) for flag in (freeze_x, freeze_y, freeze_z)]
            self.mask = jnp.array(free_axes + [0, 0, 0])
        else:
            self.mask = jnp.ones((6,))

        def _masked_step_p(v, f, m, dt):
            v += f / m * dt
            v *= self.mask
            return v

        def _half_step_x(x, v, dt):
            x += 0.5 * v * dt
            return x

        def _masked_step_t(v, noise, z1, z2):
            v = z1 * v + z2 * noise
            # Remove any noise injected into frozen components.
            v *= self.mask
            return v

        # Replace the jitted kernels installed by the parent constructor.
        self.step_p = jit(_masked_step_p)
        self.step_x = jit(_half_step_x)
        self.step_t = jit(_masked_step_t)
class OverdampedLangevinIntegrator(Integrator):
    """
    Overdamped Langevin integrator.

    Placeholder only: the constructor calls the base initializer and then
    always raises ``NotImplementedError``.

    Parameters
    ----------
    dt : float
        Time step size
    temp : float
        Temperature
    tau : float
        Relaxation time

    Raises
    ------
    NotImplementedError
        Always, on construction.
    """

    def __init__(self, dt, temp, tau):
        super().__init__(dt)
        message = "Overdamped Langevin integrator is not implemented yet."
        raise NotImplementedError(message)