"""
Classes which define the fields on the lattice.
"""
# This file is part of OpenFerro.
import numpy as np
import jax
import jax.numpy as jnp
from openferro.units import Constants
from openferro.parallelism import DeviceMesh
from openferro.integrator.llg import *
from openferro.integrator.md import *
[docs]
class Field:
"""
Template class to define a field on a lattice.
"""
[docs]
def __init__(self, lattice, ID: str):
"""
Initialize a field.
Parameters
----------
lattice : BravaisLattice3D
Lattice object
ID : str
ID of the field
"""
self.lattice = lattice
self.ID = ID
self._values = None
self._mass = None
self._velocity = None
self._force = None
self._sharding = None
self.integrator = None
self.integrator_class = None
"""
These methods are used to handle the values of the field.
"""
[docs]
def set_values(self, values):
self._values = values
return
[docs]
def get_values(self):
"""
Get the values of the field.
Returns
-------
array_like
Values of the field
Raises
------
ValueError
If field has no values set
"""
if self._values is None:
raise ValueError("Field has no values. Set values before getting them.")
else:
return self._values
@property
def size(self):
return self.get_values().size
"""
These methods are used to handle the mass of the field.
"""
[docs]
def set_mass(self, mass):
if self._values is None:
raise ValueError("Set field values before setting mass.")
else:
assert jnp.min(mass) >= 0.0, "Mass must be non-negative"
if jnp.isscalar(mass):
self._mass = jnp.zeros_like(self._values[..., 0]) + mass
self._mass = self._mass[..., None]
elif mass.shape == self._values[..., 0].shape:
self._mass = mass[..., None]
else:
raise ValueError("Mass must be a scalar or an array of the same size as the all but the last dimension of the field values.")
[docs]
def get_mass(self):
if self._mass is None:
raise ValueError("Mass is not set")
else:
return self._mass
"""
These methods are used to handle the velocity of the field.
"""
[docs]
def set_velocity(self, velocity):
if self._values is None:
raise ValueError("Field has no values. Set values before setting velocity.")
else:
self.compare_shape(velocity, self._values)
self._velocity = velocity
[docs]
def get_velocity(self):
if self._velocity is None:
raise ValueError("Velocity is not set")
else:
return self._velocity
[docs]
def init_velocity(self, mode='zero', temperature=None):
if self._values is None:
raise ValueError("Set field values before initializing velocity.")
else:
if mode == 'zero':
self._velocity = jnp.zeros_like(self._values)
elif mode == 'gaussian':
key = jax.random.PRNGKey(np.random.randint(0, 1000000))
self._velocity = jax.random.normal(key, self._values.shape) * jnp.sqrt(1 / self._mass * Constants.kb * temperature)
if self._sharding is not None:
self._velocity = jax.device_put(self._velocity, self._sharding)
"""
These methods are used to handle the force of the field.
"""
[docs]
def set_force(self, force):
if self._values is None:
raise ValueError("Field has no values. Set values before setting forces.")
else:
self.compare_shape(force, self._values)
self._force = force
[docs]
def get_force(self):
if self._force is None:
raise ValueError("Force do not exist")
else:
return self._force
[docs]
def zero_force(self):
if self._values is None:
raise ValueError("Field has no values. Set values before zeroing forces.")
else:
self._force = jnp.zeros_like(self._values)
[docs]
def accumulate_force(self, force):
if self._force is None:
raise ValueError("Gradients do not exist. Set or zero forces before accumulating.")
else:
self.compare_shape(force, self._force)
self._force += force
"""
These methods are used to handle the energy of the field.
"""
[docs]
def get_kinetic_energy(self):
if self._velocity is None:
# raise ValueError("Velocity is not set")
return 0.0
elif self._mass is None:
# raise ValueError("Mass is not set")
return 0.0
else:
return 0.5 * jnp.sum(self._mass * jnp.square(self._velocity))
[docs]
def get_temperature(self):
if self._velocity is None:
# raise ValueError("Velocity is not set")
return 0.0
elif self._mass is None:
# raise ValueError("Mass is not set")
return 0.0
else:
return jnp.mean(self._mass * jnp.square(self._velocity)) / Constants.kb
"""
Utility methods
"""
[docs]
def compare_shape(self, x, y):
if x.shape != y.shape:
raise ValueError("The two arrays to be compared have different shapes.")
[docs]
def compare_sharding(self, x, y):
if x.sharding != y.sharding:
raise ValueError("The two arrays to be compared has different sharding patterns.")
[docs]
def to_multi_devs(self, mesh: DeviceMesh):
sharding = mesh.partition_sharding()
if self._values is None:
raise ValueError("Field has no values. Set values before put the array to multiple devices.")
else:
self._values = jax.device_put(self._values, sharding)
self._sharding = sharding
if self._mass is not None:
self._mass = jax.device_put(self._mass, sharding)
if self._velocity is not None:
self._velocity = jax.device_put(self._velocity, sharding)
if self._force is not None:
self._force = jax.device_put(self._force, sharding)
"""
These methods are used to handle the integrator of the field.
"""
[docs]
def set_integrator(self, integrator_class, dt, **kwargs):
"""
Set the integrator according to the given integrator class. Set the time step.
To be implemented by the subclasses.
Parameters
----------
integrator_class : str
Class of integrator to use
dt : float
Time step
**kwargs
Additional arguments passed to integrator
"""
pass
[docs]
def set_custom_integrator(self, integrator):
self.integrator = integrator
[docs]
class FieldRn(Field):
"""
R^n field on a lattice. Values are stored as n-dimensional vectors.
"""
[docs]
def __init__(self, lattice, ID, dim, unit=None):
super().__init__(lattice, ID)
self.fdim = dim
self.ldim = lattice.dim
self.shape = [lattice.size[i] for i in range(self.ldim)] + [self.fdim]
self._values = jnp.zeros(self.shape)
self.unit = unit
self.integrator_class = {'optimization': GradientDescentIntegrator,
'adiabatic': LeapFrogIntegrator,
'isothermal': LangevinIntegrator}
@property
def mean(self):
"""
Calculate the average of the field over the lattice.
Returns
-------
array_like
Mean value of field
"""
return jnp.mean(self.get_values(), axis=[i for i in range(self.ldim)])
@property
def var(self):
"""
Calculate the variance of the field over the lattice.
Returns
-------
array_like
Variance of field
"""
return jnp.var(self.get_values(), axis=[i for i in range(self.ldim)])
[docs]
def set_local_value(self, loc, value):
"""
Set the value of the field at a given location.
Parameters
----------
loc : tuple
Location tuple with length equal to lattice dimension
value : array_like
Value to set at location
"""
assert type(loc) is tuple and len(loc) == self.ldim, "Location must be a tuple of length equal to the dimension of the lattice"
self._values[loc] = value
return
[docs]
def set_integrator(self, integrator_class, dt, temp=None, tau=None):
"""
Set the integrator according to the given integrator class. Set the time step.
Parameters
----------
integrator_class : str
Integrator class
dt : float
Time step
temp : float, optional
Temperature for isothermal integrator
tau : float, optional
Relaxation time for Langevin integrator
Raises
------
ValueError
If integrator class not supported or required parameters missing
"""
if integrator_class not in self.integrator_class:
raise ValueError(f"Integrator class {integrator_class} is not supported for this field.")
else:
if integrator_class == 'isothermal':
if temp is None or tau is None:
raise ValueError("Temperature and relaxation time must be specified for the isothermal integrator.")
else:
integrator = self.integrator_class[integrator_class](dt, temp, tau)
else:
integrator = self.integrator_class[integrator_class](dt)
self.integrator = integrator
return
[docs]
class FieldScalar(FieldRn):
"""
Scalar field. Values are stored as scalars.
Example: mass, density, any onsite constant, etc.
"""
[docs]
def __init__(self, lattice, ID, unit=None):
super().__init__(lattice, ID, dim=1, unit=unit)
[docs]
class FieldR3(FieldRn):
"""
3D vector field. Values are stored as 3-dimensional vectors.
Example: flexible dipole moment.
"""
[docs]
def __init__(self, lattice, ID, unit=None):
super().__init__(lattice, ID, dim=3, unit=unit)
[docs]
class FieldSO3(FieldRn):
"""
3D vector field with fixed magnitude and flexible orientation. Values are stored as 3-dimensional vectors.
Example: rigid atomistic spin, molecular orientation, etc.
"""
[docs]
def __init__(self, lattice, ID, unit=None):
super().__init__(lattice, ID, dim=3, unit=unit)
self._magnitude = jnp.ones(self.shape[:-1])
self.integrator_class = {'optimization': LLSIBIntegrator,
'adiabatic': ConservativeLLSIBIntegrator,
'isothermal': LLSIBLangevinIntegrator}
[docs]
def set_magnitude(self, magnitude):
if self._values is None:
raise ValueError("Field has no values. Set values before setting magnitude.")
elif jnp.isscalar(magnitude):
self._magnitude = jnp.ones(self.shape[:-1]) * magnitude
elif magnitude.shape == self._values.shape[:-1]:
self._magnitude = magnitude
else:
raise ValueError("Magnitude must be a scalar or an array of the same size as the all but the last dimension of the field values.")
self.normalize()
[docs]
def get_magnitude(self):
if self._magnitude is None:
raise ValueError("Magnitude is not set")
else:
return self._magnitude
[docs]
def perturb(self, sigma):
key = jax.random.PRNGKey(np.random.randint(0, 1000000))
self._values = self._values / jnp.linalg.norm(self._values, axis=-1, keepdims=True)
self._values += jax.random.normal(key, self._values.shape) * sigma
self.normalize()
return
[docs]
def normalize(self):
if self._values is None:
raise ValueError("Field has no values. Set values before normalizing.")
elif self._magnitude is None:
raise ValueError("Magnitude is not set. Set magnitude before normalizing.")
else:
self._values = self._values / jnp.linalg.norm(self._values, axis=-1, keepdims=True) * self._magnitude[..., None]
return
[docs]
def init_velocity(self, mode='zero', temperature=None):
pass
[docs]
def to_multi_devs(self, mesh: DeviceMesh):
sharding = mesh.partition_sharding()
if self._values is None:
raise ValueError("Field has no values. Set values before put the array to multiple devices.")
else:
self._values = jax.device_put(self._values, sharding)
self._sharding = sharding
if self._magnitude is not None:
self._magnitude = jax.device_put(self._magnitude, sharding)
if self._mass is not None:
self._mass = jax.device_put(self._mass, sharding)
if self._velocity is not None:
self._velocity = jax.device_put(self._velocity, sharding)
if self._force is not None:
self._force = jax.device_put(self._force, sharding)
return
[docs]
def set_integrator(self, integrator_class, dt, temp=None, alpha=None):
"""
Set the integrator according to the given integrator class. Set the time step.
Parameters
----------
integrator_class : str
Integrator class
dt : float
Time step
temp : float, optional
Temperature for isothermal integrator
alpha : float, optional
Gilbert damping constant for Landau-Lifshitz equation of motion
Raises
------
ValueError
If integrator class not supported or required parameters missing
"""
if integrator_class not in self.integrator_class:
raise ValueError(f"Integrator class {integrator_class} is not supported for this field.")
else:
if integrator_class == 'adiabatic':
integrator = self.integrator_class[integrator_class](dt)
elif integrator_class == 'optimization':
if alpha is None:
raise ValueError("Gilbert damping constant must be specified for the optimization integrator.")
else:
integrator = self.integrator_class[integrator_class](dt, alpha)
elif integrator_class == 'isothermal':
if alpha is None or temp is None:
raise ValueError("Gilbert damping constant and temperature must be specified for the isothermal integrator.")
else:
integrator = self.integrator_class[integrator_class](dt, temp, alpha)
self.integrator = integrator
return
[docs]
class LocalStrain3D(FieldRn):
"""
Strain field on 3D lattice are separated into local contribution (local strain field) and global contribution (homogeneous strain associated to the supercell).
The local strain field is encoded by the local displacement vector v_i(R)/a_i (a_i: the lattice vector) associated with each lattice site at R.
"""
[docs]
def __init__(self, lattice, ID):
super().__init__(lattice, ID, dim=3)
[docs]
@staticmethod
def get_local_strain_symmetric(values):
"""
Calculate the local strain field from the local displacement field.
Parameters
----------
values : array_like
Local displacement field values
Returns
-------
array_like
Local strain field with shape (l1, l2, l3, 6)
"""
padded_values = jnp.pad(values, ((1, 1), (1, 1), (1, 1), (0, 0)), mode='wrap') ## pad x,y,z axis with periodic boundary condition
grad_0, grad_1, grad_2 = jnp.gradient(padded_values, axis=(0, 1, 2))
grad_0 = grad_0[1:-1, 1:-1, 1:-1]
grad_1 = grad_1[1:-1, 1:-1, 1:-1]
grad_2 = grad_2[1:-1, 1:-1, 1:-1]
eta_1 = grad_0[..., 0] # eta_xx
eta_2 = grad_1[..., 1] # eta_yy
eta_3 = grad_2[..., 2] # eta_zz
eta_4 = (grad_1[..., 2] + grad_2[..., 1]) / 2 # eta_yz
eta_5 = (grad_0[..., 2] + grad_2[..., 0]) / 2 # eta_xz
eta_6 = (grad_0[..., 1] + grad_1[..., 0]) / 2 # eta_xy
local_strain = jnp.stack([eta_1, eta_2, eta_3, eta_4, eta_5, eta_6], axis=-1) # (l1, l2, l3, 6)
return local_strain
[docs]
@staticmethod
def get_local_strain(values):
"""
Calculate the local strain field from the local displacement field.
Implemented according to Physical Review B 52.9 (1995): 6301.
Parameters
----------
values : array_like
Local displacement field values
Returns
-------
array_like
Local strain field with shape (l1, l2, l3, 6)
"""
eta_1 = jnp.roll(values[..., 0], 1, axis=0) - values[..., 0] # vx(R-dx) - vx(R)
eta_1 = eta_1 + jnp.roll(eta_1, 1, axis=1) + jnp.roll(eta_1, 1, axis=2) + jnp.roll(jnp.roll(eta_1, 1, axis=1), 1, axis=2)
eta_1 = eta_1 / 4.0
eta_2 = jnp.roll(values[..., 1], 1, axis=1) - values[..., 1] # vy(R-dy) - vy(R)
eta_2 = eta_2 + jnp.roll(eta_2, 1, axis=0) + jnp.roll(eta_2, 1, axis=2) + jnp.roll(jnp.roll(eta_2, 1, axis=0), 1, axis=2)
eta_2 = eta_2 / 4.0
eta_3 = jnp.roll(values[..., 2], 1, axis=2) - values[..., 2] # vz(R-dz) - vz(R)
eta_3 = eta_3 + jnp.roll(eta_3, 1, axis=0) + jnp.roll(eta_3, 1, axis=1) + jnp.roll(jnp.roll(eta_3, 1, axis=0), 1, axis=1)
eta_3 = eta_3 / 4.0
eta_xy = jnp.roll(values[..., 1], 1, axis=0) - values[..., 1] # vy(R-dx) - vy(R)
eta_xy = eta_xy + jnp.roll(eta_xy, 1, axis=1) + jnp.roll(eta_xy, 1, axis=2) + jnp.roll(jnp.roll(eta_xy, 1, axis=1), 1, axis=2)
eta_yx = jnp.roll(values[..., 0], 1, axis=1) - values[..., 0] # vx(R-dy) - vx(R)
eta_yx = eta_yx + jnp.roll(eta_yx, 1, axis=0) + jnp.roll(eta_yx, 1, axis=2) + jnp.roll(jnp.roll(eta_yx, 1, axis=0), 1, axis=2)
eta_yz = jnp.roll(values[..., 2], 1, axis=1) - values[..., 2] # vz(R-dy) - vz(R)
eta_yz = eta_yz + jnp.roll(eta_yz, 1, axis=0) + jnp.roll(eta_yz, 1, axis=2) + jnp.roll(jnp.roll(eta_yz, 1, axis=0), 1, axis=2)
eta_zy = jnp.roll(values[..., 1], 1, axis=2) - values[..., 1] # vy(R-dz) - vy(R)
eta_zy = eta_zy + jnp.roll(eta_zy, 1, axis=0) + jnp.roll(eta_zy, 1, axis=1) + jnp.roll(jnp.roll(eta_zy, 1, axis=0), 1, axis=1)
eta_zx = jnp.roll(values[..., 0], 1, axis=2) - values[..., 0] # vx(R-dz) - vx(R)
eta_zx = eta_zx + jnp.roll(eta_zx, 1, axis=0) + jnp.roll(eta_zx, 1, axis=1) + jnp.roll(jnp.roll(eta_zx, 1, axis=0), 1, axis=1)
eta_xz = jnp.roll(values[..., 2], 1, axis=0) - values[..., 2] # vz(R-dx) - vz(R)
eta_xz = eta_xz + jnp.roll(eta_xz, 1, axis=1) + jnp.roll(eta_xz, 1, axis=2) + jnp.roll(jnp.roll(eta_xz, 1, axis=1), 1, axis=2)
eta_4 = (eta_yz + eta_zy) / 4.0
eta_5 = (eta_xz + eta_zx) / 4.0
eta_6 = (eta_xy + eta_yx) / 4.0
local_strain = jnp.stack([eta_1, eta_2, eta_3, eta_4, eta_5, eta_6], axis=-1) # (l1, l2, l3, 6)
return local_strain
[docs]
class GlobalStrain(Field):
"""
The homogeneous strain is represented by the strain tensor with Voigt convention, which is a 6-dimensional vector.
"""
[docs]
def __init__(self, lattice, ID):
super().__init__(lattice, ID)
self._values = jnp.zeros((6))
self.integrator_class = {'optimization': GradientDescentIntegrator_Strain,
'adiabatic': LeapFrogIntegrator_Strain,
'isothermal': LangevinIntegrator_Strain}
[docs]
def to_multi_devs(self, mesh: DeviceMesh):
sharding = mesh.replicate_sharding()
if self._values is None:
raise ValueError("Field has no values. Set values before put the array to multiple devices.")
else:
self._values = jax.device_put(self._values, sharding)
self._sharding = sharding
if self._mass is not None:
self._mass = jax.device_put(self._mass, sharding)
if self._velocity is not None:
self._velocity = jax.device_put(self._velocity, sharding)
if self._force is not None:
self._force = jax.device_put(self._force, sharding)
[docs]
def get_excess_stress(self):
return self.get_force() / self.lattice.ref_volume / Constants.bar
[docs]
def set_integrator(self, integrator_class, dt, temp=None, tau=None, freeze_x=False, freeze_y=False, freeze_z=False):
"""
Set the integrator according to the given integrator class. Set the time step.
Parameters
----------
integrator_class : str
Integrator class
dt : float
Time step
temp : float, optional
Temperature for isothermal integrator
tau : float, optional
Relaxation time for Langevin integrator
freeze_x : bool, optional
Whether to freeze x-component of strain, by default False
freeze_y : bool, optional
Whether to freeze y-component of strain, by default False
freeze_z : bool, optional
Whether to freeze z-component of strain, by default False
Raises
------
ValueError
If integrator class not supported or required parameters missing
"""
if integrator_class not in self.integrator_class:
raise ValueError(f"Integrator class {integrator_class} is not supported for this field.")
else:
if integrator_class == 'isothermal':
if temp is None or tau is None:
raise ValueError("Temperature and relaxation time must be specified for the isothermal integrator.")
else:
integrator = self.integrator_class[integrator_class](dt, temp, tau, freeze_x, freeze_y, freeze_z)
else:
integrator = self.integrator_class[integrator_class](dt, freeze_x, freeze_y, freeze_z)
self.integrator = integrator
return