"""
Classes which define the "interaction" between fields.
Each interaction is associated with a term in the Hamiltonian. Each interaction stores a function "self.energy_engine" that calculates the energy of the interaction and a function "force engine" that calculates the force of the interaction.
Only the energy engine is required. The force engine is optional. If the force engine is not set, the force will be calculated by automatic differentiation of the energy engine.
Notes
-----
This file is part of OpenFerro.
"""
import numpy as np
import jax.numpy as jnp
from jax import grad, jit
[docs]
class interaction_base:
"""
The base class to specify the interaction between fields.
"""
[docs]
def __init__(self, parameters=None):
self.parameters = parameters
self.energy_engine = None
self.force_engine = None
[docs]
def set_parameters(self, parameters):
"""
Set the parameters of the interaction.
Parameters
----------
parameters : array_like
The parameters of the interaction
Raises
------
ValueError
If parameters is not a numpy array, list, or jax array
"""
## turning list or numpy array to jax array
if isinstance(parameters, jnp.ndarray):
paras = parameters
elif isinstance(parameters, np.ndarray) or isinstance(parameters, list):
paras = jnp.array(parameters)
else:
raise ValueError("Parameters must be a numpy array, a list, or a jax array")
self.parameters = paras
[docs]
def get_parameters(self):
"""
Get the parameters of the interaction.
Returns
-------
jax.numpy.ndarray
The parameters of the interaction
"""
return self.parameters
[docs]
def set_energy_engine(self, energy_engine, enable_jit=True):
"""
Set the energy engine of the interaction.
Parameters
----------
energy_engine : callable
The energy engine of the interaction. It should take the values of the fields as input and return the energy as output.
enable_jit : bool, optional
Whether to enable jit for the energy engine, by default True
"""
if enable_jit:
self.energy_engine = jit(energy_engine)
else:
self.energy_engine = energy_engine
[docs]
def calc_energy(self):
pass
[docs]
def calc_force(self):
pass
[docs]
class self_interaction(interaction_base):
"""
A class to specify the self-interaction of a field.
Parameters
----------
field_ID : str
Identifier for the field
parameters : array_like, optional
Parameters for the interaction, by default None
"""
[docs]
def __init__(self, field_ID, parameters=None):
super().__init__( parameters)
self.field_ID = field_ID
[docs]
def create_force_engine(self, enable_jit=True):
"""
Derive the force engine of the interaction from the energy engine through automatic differentiation.
Parameters
----------
enable_jit : bool, optional
Whether to enable jit for the force engine, by default True
Raises
------
ValueError
If energy engine is not set
"""
if self.energy_engine is None:
raise ValueError("Energy engine is not set. Set energy engine before creating force engine.")
if enable_jit:
self.force_engine = jit(grad(self.energy_engine, argnums=0 ))
else:
self.force_engine = grad(self.energy_engine, argnums=0 )
return
[docs]
def calc_energy(self, field):
"""
Calculate the energy of the interaction for a given field.
Parameters
----------
field : Field
The field to calculate the energy
Returns
-------
float
The energy of the interaction
"""
field_values = field.get_values()
return self.energy_engine(field_values, self.parameters)
[docs]
def calc_force(self, field):
"""
Calculate the force of the interaction for a given field.
Parameters
----------
field : Field
The field to calculate the force
Returns
-------
jax.numpy.ndarray
The gradient of the energy with respect to the field. It has the same shape as the field.
"""
field_values = field.get_values()
gradient = self.force_engine(field_values, self.parameters)
return -gradient
[docs]
class mutual_interaction(interaction_base):
"""
A class to specify the mutual interaction between two fields.
Parameters
----------
field_1_ID : str
Identifier for the first field
field_2_ID : str
Identifier for the second field
parameters : array_like, optional
Parameters for the interaction, by default None
"""
[docs]
def __init__(self, field_1_ID, field_2_ID, parameters=None):
super().__init__( parameters)
self.field_1_ID = field_1_ID
self.field_2_ID = field_2_ID
[docs]
def create_force_engine(self, enable_jit=True):
"""
Derive the force engine of the interaction from the energy engine through automatic differentiation.
Parameters
----------
enable_jit : bool, optional
Whether to enable jit for the force engine, by default True
Raises
------
ValueError
If energy engine is not set
"""
if self.energy_engine is None:
raise ValueError("Energy engine is not set. Set energy engine before creating force engine.")
if enable_jit:
self.force_engine = jit(grad(self.energy_engine, argnums=(0, 1) ))
else:
self.force_engine = grad(self.energy_engine, argnums=(0, 1) )
[docs]
def calc_energy(self, field1, field2):
"""
Calculate the energy of the interaction for a given pair of fields.
Parameters
----------
field1 : Field
The first field
field2 : Field
The second field
Returns
-------
float
The energy of the interaction
"""
f1 = field1.get_values()
f2 = field2.get_values()
return self.energy_engine(f1, f2, self.parameters)
[docs]
def calc_force(self, field1, field2):
"""
Calculate the force of the interaction for a given pair of fields.
Parameters
----------
field1 : Field
The first field
field2 : Field
The second field
Returns
-------
tuple of jax.numpy.ndarray
The gradient of the energy with respect to the fields. It has the same shape as the fields.
"""
f1 = field1.get_values()
f2 = field2.get_values()
gradient = self.force_engine(f1, f2, self.parameters)
return (- gradient[0], - gradient[1])
[docs]
class triple_interaction:
"""
A class to specify the mutual interaction between three fields.
Parameters
----------
field_1_ID : str
Identifier for the first field
field_2_ID : str
Identifier for the second field
field_3_ID : str
Identifier for the third field
parameters : array_like, optional
Parameters for the interaction, by default None
"""
[docs]
def __init__(self, field_1_ID, field_2_ID, field_3_ID, parameters=None):
super().__init__( parameters)
self.field_1_ID = field_1_ID
self.field_2_ID = field_2_ID
self.field_3_ID = field_3_ID
[docs]
def create_force_engine(self, enable_jit=True):
"""
Derive the force engine of the interaction from the energy engine through automatic differentiation.
Parameters
----------
enable_jit : bool, optional
Whether to enable jit for the force engine, by default True
Raises
------
ValueError
If energy engine is not set
"""
if self.energy_engine is None:
raise ValueError("Energy engine is not set. Set energy engine before creating force engine.")
if enable_jit:
self.force_engine = jit(grad(self.energy_engine, argnums=(0, 1, 2) ))
else:
self.force_engine = grad(self.energy_engine, argnums=(0, 1, 2) )
[docs]
def calc_energy(self, field1, field2, field3):
"""
Calculate the energy of the interaction for a given triple of fields.
Parameters
----------
field1 : Field
The first field
field2 : Field
The second field
field3 : Field
The third field
Returns
-------
float
The energy of the interaction
"""
f1 = field1.get_values()
f2 = field2.get_values()
f3 = field3.get_values()
return self.energy_engine(f1, f2, f3, self.parameters)
[docs]
def calc_force(self, field1, field2, field3):
"""
Calculate the force of the interaction for a given triple of fields.
Parameters
----------
field1 : Field
The first field
field2 : Field
The second field
field3 : Field
The third field
Returns
-------
tuple of jax.numpy.ndarray
The gradient of the energy with respect to the fields. It has the same shape as the fields.
"""
f1 = field1.get_values()
f2 = field2.get_values()
f3 = field3.get_values()
gradient = self.force_engine(f1, f2, f3, self.parameters)
return (- gradient[0], - gradient[1], - gradient[2])