"""
Classes which specify the physical system.
Notes
-----
This file is part of OpenFerro.
"""
from time import time as timer
import logging
import numpy as np
import jax.numpy as jnp
from openferro.field import *
from openferro.interaction import *
from openferro.units import Constants
## import force engines
from openferro.engine.elastic import *
from openferro.engine.ferroelectric import *
from openferro.engine.magnetic import *
from openferro.engine.ewald import get_dipole_dipole_ewald
## import parallelism modules
from openferro.parallelism import DeviceMesh
[docs]
class System:
"""
A class to define a physical system.
A system is a lattice with fields and a Hamiltonian. Fields are added to the system by the user.
Interactions are added to the system by the user. Pointers to fields and interactions are stored in dictionaries.
Attributes
----------
lattice : BravaisLattice3D
The lattice of the system
"""
[docs]
def __init__(self, lattice ):
self.lattice = lattice
self.pbc = lattice.pbc
self._fields_dict = {}
self._self_interaction_dict = {}
self._mutual_interaction_dict = {}
self._triple_interaction_dict = {}
def __repr__(self):
return f"System with lattice {self.lattice} and fields {self._fields_dict.keys()}"
"""
Methods for fields
"""
[docs]
def get_field_by_ID(self, ID):
"""
Get a field by ID.
Parameters
----------
ID : str
ID of the field
Returns
-------
Field
The field with the given ID
Raises
------
ValueError
If field with given ID does not exist
"""
if ID in self._fields_dict:
return self._fields_dict[ID]
else:
raise ValueError('Field with the ID {} does not exist.'.format(ID))
[docs]
def get_all_fields(self):
"""
Get all fields in the system.
Returns
-------
list
All fields in the system
"""
return [self._fields_dict[ID] for ID in self._fields_dict.keys()]
[docs]
def get_all_SO3_fields(self):
return [field for field in self.get_all_fields() if isinstance(field, FieldSO3)]
[docs]
def get_all_non_SO3_fields(self):
return [field for field in self.get_all_fields() if not isinstance(field, FieldSO3)]
[docs]
def move_fields_to_multi_devs(self, mesh: DeviceMesh):
"""
Move all fields to given devices for parallelization.
Parameters
----------
mesh : DeviceMesh
The device mesh to move the fields to
"""
for ID in self._fields_dict.keys():
self._fields_dict[ID].to_multi_devs(mesh)
[docs]
def add_field(self, ID, ftype='scalar', dim=None, value=None, mass=1.0):
"""
Add a predefined field to the system.
Parameters
----------
ID : str
ID of the field
ftype : str, optional
Type of the field. Can be 'scalar', 'SO3', 'LocalStrain3D', etc
dim : int, optional
Dimension of the field. Only used for Rn fields
value : array-like, optional
Initial value of the field. Will be broadcasted to the shape of the field
mass : float or array-like, optional
Mass of the field. When mass is a float, it will be broadcasted to the shape of the field
Returns
-------
Field
The field with the given ID
Raises
------
ValueError
If field with this ID already exists or if ID is reserved
ValueError
If field type is unknown
"""
## sanity check
if ID in self._fields_dict:
raise ValueError("Field with this ID already exists. Pick another ID")
if ID == 'gstrain':
raise ValueError("The ID 'gstrain' is reserved for global strain field. Please pick another ID.")
## add field
if ftype == 'Rn':
init = jnp.array(value) if value is not None else jnp.zeros(dim)
self._fields_dict[ID] = FieldRn(self.lattice, ID, dim)
self._fields_dict[ID].set_values(jnp.zeros((self.lattice.size[0], self.lattice.size[1], self.lattice.size[2], dim)) + init)
elif ftype == 'SO3':
init = jnp.array(value) if value is not None else jnp.array([0,0,1.0])
self._fields_dict[ID] = FieldSO3(self.lattice, ID)
self._fields_dict[ID].set_values(jnp.zeros((self.lattice.size[0], self.lattice.size[1], self.lattice.size[2], 3)) + init)
elif ftype == 'LocalStrain3D':
init = jnp.array(value) if value is not None else jnp.zeros(3)
self._fields_dict[ID] = LocalStrain3D(self.lattice, ID)
self._fields_dict[ID].set_values(jnp.zeros((self.lattice.size[0], self.lattice.size[1], self.lattice.size[2], 3)) + init)
else:
raise ValueError("Unknown field type. ")
if mass is not None:
self._fields_dict[ID].set_mass(mass)
return self._fields_dict[ID]
[docs]
def add_global_strain(self, value=None, mass=1):
"""
Add a global strain to the system. Allow variable cell simulation.
Parameters
----------
value : array-like, optional
Initial value of the global strain
mass : float, optional
Effective mass of the global strain for the barostat
Returns
-------
Field
The global strain field
Raises
------
AssertionError
If value is provided but not a 6D vector
"""
ID = 'gstrain'
if value is not None:
assert len(value) == 6, "Global strain must be a 6D vector"
init = jnp.array(value)
else:
init = jnp.zeros(6)
self._fields_dict[ID] = GlobalStrain(self.lattice, ID)
self._fields_dict[ID].set_values(jnp.zeros(( 6)) + init)
self.add_pressure(0.0)
self._fields_dict[ID].set_mass(mass)
return self._fields_dict[ID]
"""
Methods for interactions
"""
@property
def interaction_dict(self):
"""
Get all interactions in the system.
Returns
-------
dict
All interactions in the system
"""
return {**self._self_interaction_dict, **self._mutual_interaction_dict, **self._triple_interaction_dict}
[docs]
def get_interaction_by_ID(self, interaction_ID):
"""
Get an interaction by ID.
Parameters
----------
interaction_ID : str
ID of the interaction
Returns
-------
Interaction
The interaction with the given ID
Raises
------
ValueError
If interaction with given ID does not exist
"""
if interaction_ID in self._self_interaction_dict:
return self._self_interaction_dict[interaction_ID]
elif interaction_ID in self._mutual_interaction_dict:
return self._mutual_interaction_dict[interaction_ID]
else:
raise ValueError("Interaction with ID {} does not exist. Existing interactions: {} {}".format(interaction_ID,
self._self_interaction_dict.keys(), self._mutual_interaction_dict.keys()))
[docs]
def _add_interaction_sanity_check(self, ID):
"""
Sanity check for adding an interaction.
Parameters
----------
ID : str
ID of the interaction to check
Raises
------
ValueError
If ID is reserved or already exists
"""
if ID == 'pV':
raise ValueError("The interaction ID 'pV' is internal. The term pV in the Hamiltonian will be added automatically when you add a global strain.")
if ID in self.interaction_dict:
raise ValueError("Interaction with ID {} already exists. Pick another ID.".format(ID))
return
"""
Methods for adding pre-defined interactions to the Hamiltonian
"""
## electric dipole-type interactions
[docs]
def add_dipole_dipole_interaction(self, ID, field_ID, prefactor=1.0, enable_jit=True):
"""
Add the long-range dipole-dipole interaction term to the Hamiltonian.
Parameters
----------
ID : str
ID of the interaction
field_ID : str
ID of the field
prefactor : float, optional
Prefactor of the interaction
enable_jit : bool, optional
Whether to use JIT compilation
Returns
-------
Interaction
The created interaction
"""
self._add_interaction_sanity_check(ID)
field = self.get_field_by_ID(field_ID)
interaction = self_interaction( field_ID)
energy_engine = get_dipole_dipole_ewald(field.lattice, sharding=field._sharding)
interaction.set_energy_engine(energy_engine, enable_jit=enable_jit)
interaction.create_force_engine(enable_jit=enable_jit)
interaction.set_parameters(jnp.array([prefactor]))
self._self_interaction_dict[ID] = interaction
return interaction
[docs]
def add_dipole_onsite_interaction(self, ID, field_ID, K2, alpha, gamma, enable_jit=True):
self._add_interaction_sanity_check(ID)
interaction = self_interaction(field_ID)
interaction.set_energy_engine(self_energy_onsite_isotropic, enable_jit=enable_jit)
interaction.create_force_engine(enable_jit=enable_jit)
interaction.set_parameters(jnp.array([K2, alpha, gamma]))
self._self_interaction_dict[ID] = interaction
return interaction
[docs]
def add_dipole_interaction_1st_shell(self, ID, field_ID, j1, j2, enable_jit=True):
self._add_interaction_sanity_check(ID)
interaction = self_interaction(field_ID)
interaction.set_energy_engine(short_range_1stnn_isotropic, enable_jit=enable_jit)
interaction.create_force_engine(enable_jit=enable_jit)
interaction.set_parameters(jnp.array([j1, j2]))
self._self_interaction_dict[ID] = interaction
return interaction
[docs]
def add_dipole_interaction_2nd_shell(self, ID, field_ID, j3, j4, j5, enable_jit=True):
self._add_interaction_sanity_check(ID)
interaction = self_interaction(field_ID)
interaction.set_energy_engine(short_range_2ednn_isotropic, enable_jit=enable_jit)
interaction.create_force_engine(enable_jit=enable_jit)
interaction.set_parameters(jnp.array([j3, j4, j5]))
self._self_interaction_dict[ID] = interaction
return interaction
[docs]
def add_dipole_interaction_3rd_shell(self, ID, field_ID, j6, j7, enable_jit=True):
self._add_interaction_sanity_check(ID)
interaction = self_interaction(field_ID)
energy_engine = get_short_range_3rdnn_isotropic()
interaction.set_energy_engine(energy_engine, enable_jit=enable_jit)
interaction.create_force_engine(enable_jit=enable_jit)
interaction.set_parameters(jnp.array([j6, j7]))
self._self_interaction_dict[ID] = interaction
return interaction
[docs]
def add_dipole_efield_interaction(self, ID, field_ID, E, enable_jit=True):
self._add_interaction_sanity_check(ID)
interaction = self_interaction(field_ID)
interaction.set_energy_engine(dipole_efield_interaction, enable_jit=enable_jit)
interaction.create_force_engine(enable_jit=enable_jit)
interaction.set_parameters(jnp.array(E))
self._self_interaction_dict[ID] = interaction
return interaction
## elastic-type interactions
[docs]
def add_homo_elastic_interaction(self, ID, field_ID, B11, B12, B44, enable_jit=True):
N = float(self.lattice.nsites)
self._add_interaction_sanity_check(ID)
interaction = self_interaction(field_ID)
interaction.set_energy_engine(homo_elastic_energy, enable_jit=enable_jit)
interaction.create_force_engine(enable_jit=enable_jit)
interaction.set_parameters(jnp.array([B11, B12, B44, N]))
self._self_interaction_dict[ID] = interaction
return interaction
[docs]
def add_inhomo_elastic_interaction(self, ID, field_ID, B11, B12, B44, enable_jit=True):
self._add_interaction_sanity_check(ID)
interaction = self_interaction(field_ID)
interaction.set_energy_engine(inhomo_elastic_energy, enable_jit=enable_jit)
interaction.create_force_engine(enable_jit=enable_jit)
interaction.set_parameters(jnp.array([B11, B12, B44]))
self._self_interaction_dict[ID] = interaction
return interaction
[docs]
def add_pressure(self, pressure):
"""
Add a pressure term (pV) to the Hamiltonian.
The ID of the interaction is reserved as 'pV'. V is the volume of the system,
which is calculated from the reference lattice vectors and the global strain.
Parameters
----------
pressure : float
Pressure in bars
Returns
-------
Interaction
The pV interaction
Raises
------
ValueError
If pV term already exists or if gstrain field is invalid
"""
_pres = pressure * Constants.bar # bar -> eV/Angstrom^3
## interaction ID sanity check
ID = 'pV'
if ID in self.interaction_dict:
raise ValueError("pV term already exists in the Hamiltonian.")
## field ID sanity check
field_ID = 'gstrain'
field = self.get_field_by_ID(field_ID)
if not isinstance(field, GlobalStrain):
raise ValueError("I find a field with ID <gstrain>, but it is not a global strain field. Please rename the field or remove it. Then add the global strain field to the system.")
## add the interaction
interaction = self_interaction(field_ID)
interaction.set_energy_engine(energy_engine=pV_energy, enable_jit=True)
interaction.create_force_engine(enable_jit=True)
parameters = jnp.array([_pres, self.lattice.ref_volume])
interaction.set_parameters(parameters)
self._self_interaction_dict[ID] = interaction
return interaction
## elastic-dipole interactions
[docs]
def add_homo_strain_dipole_interaction(self, ID, field_1_ID, field_2_ID, B1xx, B1yy, B4yz, enable_jit=True):
self._add_interaction_sanity_check(ID)
interaction = mutual_interaction(field_1_ID, field_2_ID)
interaction.set_energy_engine(homo_strain_dipole_interaction, enable_jit=enable_jit)
interaction.create_force_engine(enable_jit=enable_jit)
interaction.set_parameters(jnp.array([B1xx, B1yy, B4yz]))
self._mutual_interaction_dict[ID] = interaction
return interaction
[docs]
def add_inhomo_strain_dipole_interaction(self, ID, field_1_ID, field_2_ID, B1xx, B1yy, B4yz, enable_jit=True):
self._add_interaction_sanity_check(ID)
interaction = mutual_interaction(field_1_ID, field_2_ID)
energy_engine = get_inhomo_strain_dipole_interaction(enable_jit=enable_jit)
interaction.set_energy_engine(energy_engine, enable_jit=enable_jit)
interaction.create_force_engine(enable_jit=enable_jit)
interaction.set_parameters(jnp.array([B1xx, B1yy, B4yz]))
self._mutual_interaction_dict[ID] = interaction
return interaction
## atomistic spin-type interactions
[docs]
def add_cubic_anisotropy_interaction(self, ID, field_ID, K1, K2, enable_jit=True):
"""
Add the cubic anisotropy interaction term.
Parameters
----------
ID : str
ID of the interaction
field_ID : str
ID of the field
K1 : float
First anisotropy constant
K2 : float
Second anisotropy constant
enable_jit : bool, optional
Whether to use JIT compilation
Returns
-------
Interaction
The created interaction
"""
self._add_interaction_sanity_check(ID)
interaction = self_interaction(field_ID)
interaction.set_energy_engine(cubic_anisotropy_energy, enable_jit=enable_jit)
interaction.create_force_engine(enable_jit=enable_jit)
interaction.set_parameters(jnp.array([K1, K2]))
self._self_interaction_dict[ID] = interaction
return interaction
[docs]
def _add_isotropic_exchange_interaction_by_rollers(self, ID, field_ID, coupling, rollers, enable_jit=True):
"""
Add the isotropic exchange interaction term H=sum_{i~j} Jij*Si*Sj to the Hamiltonian.
Parameters
----------
ID : str
ID of the interaction
field_ID : str
ID of the field
coupling : float
Coupling constant
rollers : list
List of rolling functions for specifying the neighbouring relationship
enable_jit : bool, optional
Whether to use JIT compilation
Returns
-------
Interaction
The created interaction
"""
self._add_interaction_sanity_check(ID)
energy_engine = get_isotropic_exchange_energy_engine(rollers)
interaction = self_interaction(field_ID)
interaction.set_energy_engine(energy_engine, enable_jit=enable_jit)
interaction.create_force_engine(enable_jit=enable_jit)
interaction.set_parameters(jnp.array([coupling]))
self._self_interaction_dict[ID] = interaction
return interaction
[docs]
def add_isotropic_exchange_interaction_1st_shell(self, ID, field_ID, coupling, enable_jit=True):
"""
Add the first shell isotropic exchange interaction term.
The first shell is defined in lattice class.
Parameters
----------
ID : str
ID of the interaction
field_ID : str
ID of the field
coupling : float
Coupling constant
enable_jit : bool, optional
Whether to use JIT compilation
Returns
-------
Interaction
The created interaction
"""
interaction = self._add_isotropic_exchange_interaction_by_rollers(
ID, field_ID, coupling, self.lattice.first_shell_roller, enable_jit=enable_jit)
return interaction
[docs]
def add_isotropic_exchange_interaction_2nd_shell(self, ID, field_ID, coupling, enable_jit=True):
"""
Add the second shell isotropic exchange interaction term.
The second shell is defined in lattice class.
Parameters
----------
ID : str
ID of the interaction
field_ID : str
ID of the field
coupling : float
Coupling constant
enable_jit : bool, optional
Whether to use JIT compilation
Returns
-------
Interaction
The created interaction
"""
interaction = self._add_isotropic_exchange_interaction_by_rollers(
ID, field_ID, coupling, self.lattice.second_shell_roller, enable_jit=enable_jit)
return interaction
[docs]
def add_isotropic_exchange_interaction_3rd_shell(self, ID, field_ID, coupling, enable_jit=True):
"""
Add the third shell isotropic exchange interaction term.
The third shell is defined in lattice class.
Parameters
----------
ID : str
ID of the interaction
field_ID : str
ID of the field
coupling : float
Coupling constant
enable_jit : bool, optional
Whether to use JIT compilation
Returns
-------
Interaction
The created interaction
"""
interaction = self._add_isotropic_exchange_interaction_by_rollers(
ID, field_ID, coupling, self.lattice.third_shell_roller, enable_jit=enable_jit)
return interaction
[docs]
def add_isotropic_exchange_interaction_4th_shell(self, ID, field_ID, coupling, enable_jit=True):
"""
Add the fourth shell isotropic exchange interaction term.
The fourth shell is defined in lattice class.
Parameters
----------
ID : str
ID of the interaction
field_ID : str
ID of the field
coupling : float
Coupling constant
enable_jit : bool, optional
Whether to use JIT compilation
Returns
-------
Interaction
The created interaction
"""
interaction = self._add_isotropic_exchange_interaction_by_rollers(
ID, field_ID, coupling, self.lattice.fourth_shell_roller, enable_jit=enable_jit)
return interaction
"""
Methods for adding custom interactions to the Hamiltonian. Energy engines should be provided by the user.
"""
[docs]
def add_self_interaction(self, ID, field_ID, energy_engine, parameters=None, enable_jit=True):
"""
Add a custom self-interaction term to the Hamiltonian.
Parameters
----------
ID : str
ID of the interaction
field_ID : str
ID of the field
energy_engine : callable
A function that takes the field as input and returns the interaction energy
parameters : array-like, optional
Parameters for the interaction
enable_jit : bool, optional
Whether to use JIT compilation
Returns
-------
Interaction
The created interaction
"""
self._add_interaction_sanity_check(ID)
interaction = self_interaction( field_ID)
interaction.set_energy_engine(energy_engine, enable_jit=enable_jit)
interaction.create_force_engine(enable_jit=enable_jit)
if parameters is not None:
interaction.set_parameters(parameters)
self._self_interaction_dict[ID] = interaction
return interaction
[docs]
def add_mutual_interaction(self, ID, field_1_ID, field_2_ID, energy_engine, parameters=None, enable_jit=True):
"""
Add a custom mutual interaction term to the Hamiltonian.
Parameters
----------
ID : str
ID of the interaction
field_1_ID : str
ID of the first field
field_2_ID : str
ID of the second field
energy_engine : callable
A function that takes the fields as input and returns the interaction energy
parameters : array-like, optional
Parameters for the interaction
enable_jit : bool, optional
Whether to use JIT compilation
Returns
-------
Interaction
The created interaction
"""
self._add_interaction_sanity_check(ID)
interaction = mutual_interaction( field_1_ID, field_2_ID)
interaction.set_energy_engine(energy_engine, enable_jit=enable_jit)
interaction.create_force_engine(enable_jit=enable_jit)
if parameters is not None:
interaction.set_parameters(parameters)
self._mutual_interaction_dict[ID] = interaction
return interaction
[docs]
def add_triple_interaction(self, ID, field_1_ID, field_2_ID, field_3_ID, energy_engine, parameters=None, enable_jit=True):
"""
Add a custom triple interaction term to the Hamiltonian.
Parameters
----------
ID : str
ID of the interaction
field_1_ID : str
ID of the first field
field_2_ID : str
ID of the second field
field_3_ID : str
ID of the third field
energy_engine : callable
A function that takes the fields as input and returns the interaction energy
parameters : array-like, optional
Parameters for the interaction
enable_jit : bool, optional
Whether to use JIT compilation
Returns
-------
Interaction
The created interaction
"""
self._add_interaction_sanity_check(ID)
interaction = triple_interaction(field_1_ID, field_2_ID, field_3_ID)
interaction.set_energy_engine(energy_engine, enable_jit=enable_jit)
interaction.create_force_engine(enable_jit=enable_jit)
if parameters is not None:
interaction.set_parameters(parameters)
self._triple_interaction_dict[ID] = interaction
return interaction
"""
Methods for energy and force calculation
"""
[docs]
def calc_energy_by_ID(self, interaction_ID):
"""
Calculate the energy of an interaction by ID.
Parameters
----------
interaction_ID : str
ID of the interaction
Returns
-------
float
The energy of the interaction
Raises
------
ValueError
If interaction with given ID does not exist
"""
if interaction_ID in self._self_interaction_dict:
interaction = self.get_interaction_by_ID(interaction_ID)
field = self.get_field_by_ID(interaction.field_ID)
energy = interaction.calc_energy(field)
elif interaction_ID in self._mutual_interaction_dict:
interaction = self.get_interaction_by_ID(interaction_ID)
field1 = self.get_field_by_ID(interaction.field_1_ID)
field2 = self.get_field_by_ID(interaction.field_2_ID)
energy = interaction.calc_energy(field1, field2)
elif interaction_ID in self._triple_interaction_dict:
interaction = self.get_interaction_by_ID(interaction_ID)
field1 = self.get_field_by_ID(interaction.field_1_ID)
field2 = self.get_field_by_ID(interaction.field_2_ID)
field3 = self.get_field_by_ID(interaction.field_3_ID)
energy = interaction.calc_energy(field1, field2, field3)
else:
raise ValueError("Interaction with ID {} does not exist. Existing interactions: {} {}".format(interaction_ID,
self._self_interaction_dict.keys(), self._mutual_interaction_dict.keys()))
return energy
[docs]
def calc_force_by_ID(self, interaction_ID):
"""
Calculate the gradient force from an interaction by ID.
Parameters
----------
interaction_ID : str
ID of the interaction
Returns
-------
array
The gradient force of the interaction
Raises
------
ValueError
If interaction with given ID does not exist
"""
if interaction_ID in self._self_interaction_dict:
interaction = self.get_interaction_by_ID(interaction_ID)
field = self.get_field_by_ID(interaction.field_ID)
force = interaction.calc_force(field)
elif interaction_ID in self._mutual_interaction_dict:
interaction = self.get_interaction_by_ID(interaction_ID)
field1 = self.get_field_by_ID(interaction.field_1_ID)
field2 = self.get_field_by_ID(interaction.field_2_ID)
force = interaction.calc_force(field1, field2)
elif interaction_ID in self._triple_interaction_dict:
interaction = self.get_interaction_by_ID(interaction_ID)
field1 = self.get_field_by_ID(interaction.field_1_ID)
field2 = self.get_field_by_ID(interaction.field_2_ID)
field3 = self.get_field_by_ID(interaction.field_3_ID)
force = interaction.calc_force(field1, field2, field3)
else:
raise ValueError("Interaction with this ID does not exist. Existing interactions: ",
self._self_interaction_dict.keys(), self._mutual_interaction_dict.keys())
return force
[docs]
def calc_total_self_energy(self):
"""
Calculate the total self-interaction energy.
Returns
-------
float
Total self-interaction energy
"""
energy = 0.0
for interaction_ID in self._self_interaction_dict:
energy += self.calc_energy_by_ID(interaction_ID)
e = self.calc_energy_by_ID(interaction_ID)
# logging.info('Energy from {}: {}'.format(interaction_ID, e))
return energy
[docs]
def calc_total_mutual_interaction(self):
"""
Calculate the total mutual interaction energy.
Returns
-------
float
Total mutual interaction energy
"""
energy = 0.0
for interaction_ID in self._mutual_interaction_dict:
energy += self.calc_energy_by_ID(interaction_ID)
e = self.calc_energy_by_ID(interaction_ID)
# logging.info('Energy from {}: {}'.format(interaction_ID, e))
return energy
[docs]
def calc_total_triple_interaction(self):
"""
Calculate the total triple interaction energy.
Returns
-------
float
Total triple interaction energy
"""
energy = 0.0
for interaction_ID in self._triple_interaction_dict:
energy += self.calc_energy_by_ID(interaction_ID)
return energy
[docs]
def calc_total_potential_energy(self):
"""
Calculate the total potential energy of the system.
Calculate the total potential energy of the system.
Total potential energy is the sum of self-interaction energy, mutual interaction energy, and triple interaction energy.
"""
return self.calc_total_self_energy() + self.calc_total_mutual_interaction() + self.calc_total_triple_interaction()
[docs]
def calc_total_kinetic_energy(self):
"""
Calculate the total kinetic energy of the system.
"""
kinetic_energy = 0.0
for field in self.get_all_fields():
kinetic_energy += field.get_kinetic_energy()
return kinetic_energy
[docs]
def calc_temp_by_ID(self, ID):
"""
Calculate the temperature of a field.
"""
field = self.get_field_by_ID(ID)
return field.get_temperature()
[docs]
def calc_excess_stress(self):
'''
Get instantaneous stress - applied stress (e.g. from hydrostatic pressure)
'''
field = self.get_field_by_ID('gstrain')
return field.get_excess_stress()
"""
Methods for updating the gradient force
"""
[docs]
def update_force_from_self_interaction(self, profile=False):
"""
update the gradient force felt by each field from self interactions.
"""
for interaction_ID in self._self_interaction_dict:
if profile:
t0 = timer()
interaction = self._self_interaction_dict[interaction_ID]
field = self.get_field_by_ID(interaction.field_ID)
force = interaction.calc_force(field)
field.accumulate_force(force)
if profile:
jax.block_until_ready(field.get_force())
logging.info('Time for updating force from {}: {:.8f}s'.format(interaction_ID, timer()-t0))
# print("energy from %s:" % interaction_ID, interaction.calc_energy(field))
return
[docs]
def update_force_from_mutual_interaction(self, profile=False):
"""
update the gradient force felt by each field from mutual interactions.
"""
for interaction_ID in self._mutual_interaction_dict:
if profile:
t0 = timer()
interaction = self.get_interaction_by_ID(interaction_ID)
field1 = self.get_field_by_ID(interaction.field_1_ID)
field2 = self.get_field_by_ID(interaction.field_2_ID)
force1, force2 = interaction.calc_force(field1, field2)
field1.accumulate_force(force1)
field2.accumulate_force(force2)
if profile:
jax.block_until_ready(field2.get_force())
logging.info('Time for updating force from {}: {:.8f}s'.format(interaction_ID, timer()-t0))
return
[docs]
def update_force_from_triple_interaction(self, profile=False):
"""
update the gradient force felt by each field from triple interactions.
"""
for interaction_ID in self._triple_interaction_dict:
if profile:
t0 = timer()
interaction = self.get_interaction_by_ID(interaction_ID)
field1 = self.get_field_by_ID(interaction.field_1_ID)
field2 = self.get_field_by_ID(interaction.field_2_ID)
field3 = self.get_field_by_ID(interaction.field_3_ID)
force1, force2, force3 = interaction.calc_force(field1, field2, field3)
field1.accumulate_force(force1)
field2.accumulate_force(force2)
field3.accumulate_force(force3)
if profile:
jax.block_until_ready(field3.get_force())
logging.info('Time for updating force from {}: {:.8f}s'.format(interaction_ID, timer()-t0))
return
[docs]
def update_force(self, profile=False):
"""
update the gradient force felt by each field from all interactions.
"""
## zero force
for field in self.get_all_fields():
if profile:
t0 = timer()
field.zero_force()
if profile:
jax.block_until_ready(field.get_force())
logging.info('Time for zeroing force of {}: {:.8f}s'.format(field.ID, timer()-t0))
## update force from all interactions
self.update_force_from_self_interaction(profile=profile)
self.update_force_from_mutual_interaction(profile=profile)
self.update_force_from_triple_interaction(profile=profile)
return
[docs]
class RingPolymerSystem(System):
"""
A class to define a ring polymer system for path-integral molecular dynamics simulations.
"""
[docs]
def __init__(self, lattice, nbeads=1):
super().__init__(lattice)
self.nbeads = nbeads
raise NotImplementedError("Ring polymer system is not implemented yet.")