Source code for openferro.simulation

"""
Classes which define the time evolution of physical systems.

Notes
-----
This file is part of OpenFerro.
"""

from time import time as timer
import logging
import numpy as np
import jax
import jax.numpy as jnp
from functools import partial
from openferro.units import Constants
from openferro.field import GlobalStrain
from openferro.reporter import Thermo_Reporter, Field_Reporter

[docs] class Simulation: """ The base class to define a simulation. A simulation controls the time evolution of a system. This class does not implement any time integration algorithm. Each field object has its own integrator. The class only calls the step method of each integrator and controls output. Parameters ---------- system : System The physical system to simulate """
[docs] def __init__(self, system): self.system = system self.reporters = []
[docs] def clear_reporters(self): self.reporters = []
[docs] def add_thermo_reporter(self, file='thermo.log', log_interval=100, global_strain=False, excess_stress=False, volume=False, potential_energy=False, kinetic_energy=False, temperature=False): """ Add a reporter to output global thermodynamic information. Parameters ---------- file : str, optional Output file name, by default 'thermo.log' log_interval : int, optional Number of steps between outputs, by default 100 global_strain : bool, optional Whether to output global strain, by default False excess_stress : bool, optional Whether to output excess stress, by default False volume : bool, optional Whether to output volume, by default False potential_energy : bool, optional Whether to output potential energy, by default False kinetic_energy : bool, optional Whether to output kinetic energy, by default False temperature : bool, optional Whether to output temperature, by default False """ self.reporters.append(Thermo_Reporter(file, log_interval, global_strain, excess_stress, volume, potential_energy, kinetic_energy, temperature))
[docs] def add_field_reporter(self, file_prefix, field_ID, log_interval=100, field_average=True, dump_field=False): """ Add a reporter to dump the values of a given field. Parameters ---------- file_prefix : str Prefix for output files field_ID : str ID of field to report log_interval : int, optional Number of steps between outputs, by default 100 field_average : bool, optional Whether to output field averages, by default True dump_field : bool, optional Whether to dump full field values, by default False """ self.reporters.append(Field_Reporter(file_prefix, field_ID, log_interval, field_average, dump_field))
[docs] def initialize_reporters(self): """Initialize all reporters.""" for reporter in self.reporters: reporter.initialize(self.system)
[docs] def remove_all_reporters(self): self.reporters = []
[docs] def reset_reporters(self): """Reset the counters of all reporters.""" for reporter in self.reporters: reporter.counter = -1
[docs] def step_reporters(self): """Step all reporters.""" for reporter in self.reporters: reporter.step(self.system)
[docs] def init_velocity(self, mode='zero', temp=None): for field in self.all_fields: field.init_velocity(mode=mode, temperature=temp)
[docs] def _step(self): """ Update the system by one time step. To be implemented by subclasses. """ pass
[docs] def run(self): """ Run the simulation for a given number of steps or until convergence. To be implemented by subclasses. """ pass
[docs] class MDMinimize(Simulation): """ Class for energy minimization using molecular dynamics. Parameters ---------- system : System The physical system to minimize max_iter : int, optional Maximum number of iterations, by default 100 tol : float, optional Force tolerance for convergence, by default 1e-5 """
[docs] def __init__(self, system, max_iter=100, tol=1e-5 ): super().__init__(system) self.max_iter = max_iter self.tol = tol self.all_fields = self.system.get_all_fields()
[docs] def _step(self, variable_cell): """ Update the field by one time step. Parameters ---------- variable_cell : bool Whether to allow cell parameters to vary """ SO3_fields = self.system.get_all_SO3_fields() non_SO3_fields = self.system.get_all_non_SO3_fields() if len(non_SO3_fields) > 0: ## update the force for all fields. ## Force will not be updated again while integrating each non-SO3 field with simple explicit integrator. self.system.update_force() for field in non_SO3_fields: if (variable_cell is False) and isinstance(field, GlobalStrain): continue field.integrator.step(field) if len(SO3_fields) > 0: ## Force updater will be passed to the integrator of each SO3 fields because implicit methods are used. ## So the force will not be updated here. for field in SO3_fields: field.integrator.step(field, force_updater=self.system.update_force)
[docs] def run(self, variable_cell=True, pressure=None): """ Run the minimization. Parameters ---------- variable_cell : bool, optional Whether to allow cell parameters to vary, by default True pressure : float, optional External pressure in bar, required if variable_cell=True Raises ------ ValueError If pressure not specified for variable cell minimization If pressure specified for fixed cell minimization If integrator not set for any field """ ## sanity check if variable_cell: if pressure is None: raise ValueError('Please specify pressure for variable-cell structural minimization') else: # self.system.get_interaction_by_ID('pV').set_parameter_by_ID( # 'p', pressure * Constants.bar) # bar -> eV/Angstrom^3 pV_param = self.system.get_interaction_by_ID('pV').get_parameters() pV_param_new = [pressure * Constants.bar, pV_param[1]] self.system.get_interaction_by_ID('pV').set_parameters(pV_param_new) for field in self.all_fields: if field.integrator is None: raise ValueError('Please set the integrator for the field %s for variable-cell structural minimization' % type(field)) else: if pressure is not None: raise ValueError('Specifying pressure is not allowed for fixed-cell structural minimization') for field in [field for field in self.all_fields if not isinstance(field, GlobalStrain)]: if field.integrator is None: raise ValueError('Please set the integrator for the field %s for fixed-cell structural minimization' % type(field)) ## structural relaxation self.initialize_reporters() for i in range(self.max_iter): self._step(variable_cell) self.step_reporters() converged = [] for field in self.all_fields: if jnp.max(jnp.abs(field.get_force())) < self.tol: converged.append(True) else: converged.append(False) if all(converged): break
[docs] class SimulationNVE(Simulation): """ Class for NVE (microcanonical) molecular dynamics simulation. Parameters ---------- system : System The physical system to simulate """
[docs] def __init__(self, system): super().__init__(system) ## get all fields, excluding the global strain field self.SO3_fields = self.system.get_all_SO3_fields() self.non_SO3_fields = [field for field in self.system.get_all_non_SO3_fields() if not isinstance(field, GlobalStrain)] self.all_fields = self.SO3_fields + self.non_SO3_fields self.nfields = len(self.all_fields)
[docs] def _step(self, profile=False): """ Update the field by one step. Parameters ---------- profile : bool, optional Whether to profile timing, by default False """ if len(self.non_SO3_fields) > 0: ## update the force for all fields. ## Force will not be updated again while integrating each non-SO3 field with simple explicit integrator. self.system.update_force(profile=profile) for field in self.non_SO3_fields: if profile: t0 = timer() field.integrator.step(field) if profile: jax.block_until_ready(field.get_values()) logging.info('Time for updating field {}: {:.8f}s'.format(type(field), timer()-t0)) if len(self.SO3_fields) > 0: ## Force updater will be passed to the integrator of each SO3 fields because implicit methods are used. ## So the force will not be updated here. for field in self.SO3_fields: if profile: t0 = timer() field.integrator.step(field, force_updater=self.system.update_force) if profile: jax.block_until_ready(field.get_values()) logging.info('Time for updating field {}: {:.8f}s'.format(type(field), timer()-t0))
[docs] def run(self, nsteps=1, profile=False): """ Run the simulation. Parameters ---------- nsteps : int, optional Number of steps to run, by default 1 profile : bool, optional Whether to profile timing, by default False Raises ------ ValueError If integrator not set for any field """ ## sanity check for field in self.all_fields: if field.integrator is None: raise ValueError('Please set the integrator for the field %s before running the simulation' % type(field)) ## run the simulation self.initialize_reporters() for i in range(nsteps): self._step(profile=profile) self.step_reporters()
[docs] class SimulationNVTLangevin(SimulationNVE): """ Class for NVT molecular dynamics simulation using Langevin dynamics. Parameters ---------- system : System The physical system to simulate """
[docs] def __init__(self, system): super().__init__(system)
[docs] def _step(self, keys, profile=False): """ Update the field by one step. Parameters ---------- keys : array_like Random keys for Langevin dynamics profile : bool, optional Whether to profile timing, by default False """ keys_SO3 = keys[:len(self.SO3_fields)] keys_non_SO3 = keys[len(self.SO3_fields):] if len(self.non_SO3_fields) > 0: self.system.update_force(profile=profile) for field, subkey in zip(self.non_SO3_fields, keys_non_SO3): if profile: t0 = timer() field.integrator.step(subkey, field) if profile: jax.block_until_ready(field.get_values()) logging.info('Time for updating field {}: {:.8f}s'.format(type(field), timer()-t0)) if len(self.SO3_fields) > 0: for field, subkey in zip(self.SO3_fields, keys_SO3): if profile: t0 = timer() field.integrator.step(subkey, field, force_updater=self.system.update_force) if profile: jax.block_until_ready(field.get_values()) logging.info('Time for updating field {}: {:.8f}s'.format(type(field), timer()-t0)) return
[docs] def run(self, nsteps=1, profile=False): """ Run the simulation. Parameters ---------- nsteps : int, optional Number of steps to run, by default 1 profile : bool, optional Whether to profile timing, by default False Raises ------ ValueError If integrator not set for any field """ ## sanity check for field in self.all_fields: if field.integrator is None: raise ValueError('Please set the integrator for the field %s before running the simulation' % type(field)) ## generate all the needed random keys in advance key = jax.random.PRNGKey(np.random.randint(0, 1000000)) keys = jax.random.split(key, nsteps * self.nfields) ## run the simulation self.initialize_reporters() for id_step in range(nsteps): if profile: t0 = timer() subkeys = keys[id_step * self.nfields:(id_step+1) * self.nfields] self._step(subkeys, profile) self.step_reporters() if profile: logging.info('Total time for one step: {:.8f}s'.format(timer()-t0)) return
[docs] class SimulationNPTLangevin(SimulationNVTLangevin): """ Class for NPT molecular dynamics simulation using Langevin dynamics. Parameters ---------- system : System The physical system to simulate pressure : float, optional External pressure in bar, by default 0.0 """
[docs] def __init__(self, system, pressure=0.0): super().__init__(system) ## set pressure self.pressure = pressure pV_param = self.system.get_interaction_by_ID('pV').get_parameters() pV_param_new = [pressure * Constants.bar, pV_param[1]] self.system.get_interaction_by_ID('pV').set_parameters(pV_param_new) ## get all fields, including the global strain field self.SO3_fields = self.system.get_all_SO3_fields() self.non_SO3_fields = self.system.get_all_non_SO3_fields() self.all_fields = self.SO3_fields + self.non_SO3_fields self.nfields = len(self.all_fields)