Source code for openferro.engine.multiferroic

"""Functions for multiferroic energy.

These functions will be added into <class interaction> for automatic differentiation.

Notes
-----
This file is part of OpenFerro.
"""
import jax.numpy as jnp
from jax import jit
from openferro.field import LocalStrain3D

##########################################################################################
## Utilities
##########################################################################################

def mean_field_on_interwaving_sublattice(field):
    """
    Returns the mean-field of a field on an interweaving sublattice.

    Say we have two cubic sublattices A and B, and a field defined on sublattice A.
    This function computes the mean-field of that field on sublattice B, i.e. the
    average over the 8 nearest A-neighbors of each B-site.

    Let the lattice constant be unit length. We assume that the origin of sublattice B
    is at (0,0,0) and the origin of sublattice A is at (-0.5,-0.5,-0.5), so (0,0,0) and
    (-0.5,-0.5,-0.5) make a primitive cell. This convention is adopted throughout this
    file. It is not unique; it is just for consistency among the energy engines defined
    for multiferroics, to avoid confusion over the many calls to jnp.roll.

    Parameters
    ----------
    field : jnp.array
        The field (defined on sublattice A) to average. The first three axes are the
        lattice axes; periodic boundary conditions are implied by jnp.roll.

    Returns
    -------
    jnp.array
        The mean-field of the field on sublattice B, same shape as ``field``.
    """
    # With the convention above, A-site index i sits at coordinate (i - 0.5) along each
    # axis, so the 8 A-neighbors of B-site i are the A-sites i + d with d in {0, 1}^3.
    # jnp.roll(x, -1, axis)[i] == x[i + 1], hence every non-zero shift must be -1
    # along all three axes (a +1 shift would pick the A-site at distance 1.5 instead).
    mean_field = 0
    for shift_x in (0, -1):
        for shift_y in (0, -1):
            for shift_z in (0, -1):
                mean_field = mean_field + jnp.roll(
                    field, (shift_x, shift_y, shift_z), axis=(0, 1, 2)
                )
    return mean_field / 8
def mean_field_1stnn_on_single_lattice(field):
    """
    Returns the average of a field over the 6 first-nearest neighbors of each site
    on a cubic lattice (periodic wrap-around via jnp.roll).

    Parameters
    ----------
    field : jnp.array
        The field to average; the first three axes are the lattice axes.

    Returns
    -------
    jnp.array
        The neighbor-averaged field, same shape as ``field``.
    """
    neighbor_sum = 0
    # One step forward and one step backward along each of the three lattice axes.
    for lattice_axis in range(3):
        for step in (1, -1):
            neighbor_sum = neighbor_sum + jnp.roll(field, step, axis=lattice_axis)
    return neighbor_sum / 6
def mean_field_2ndnn_on_single_lattice(field):
    """
    Returns the average of a field over the 12 second-nearest neighbors of each site
    on a cubic lattice (periodic wrap-around via jnp.roll).

    Parameters
    ----------
    field : jnp.array
        The field to average; the first three axes are the lattice axes.

    Returns
    -------
    jnp.array
        The neighbor-averaged field, same shape as ``field``.
    """
    neighbor_sum = 0
    # Second neighbors are the face diagonals: a +/-1 step along two distinct lattice
    # axes at once, for each of the three axis pairs.
    for axis_pair in ((0, 1), (0, 2), (1, 2)):
        for step_a in (1, -1):
            for step_b in (1, -1):
                neighbor_sum = neighbor_sum + jnp.roll(field, (step_a, step_b), axis=axis_pair)
    return neighbor_sum / 12
def mean_field_3rdnn_on_single_lattice(field):
    """
    Returns the average of a field over the 8 third-nearest neighbors of each site
    on a cubic lattice (periodic wrap-around via jnp.roll).

    Parameters
    ----------
    field : jnp.array
        The field to average; the first three axes are the lattice axes.

    Returns
    -------
    jnp.array
        The neighbor-averaged field, same shape as ``field``.
    """
    neighbor_sum = 0
    # Third neighbors are the body diagonals: a +/-1 step along all three lattice axes.
    for step_z in (1, -1):
        for step_x in (1, -1):
            for step_y in (1, -1):
                neighbor_sum = neighbor_sum + jnp.roll(
                    field, (step_x, step_y, step_z), axis=(0, 1, 2)
                )
    return neighbor_sum / 8
##########################################################################################
## Energies involving AFD field, strain and dipole field
##########################################################################################
def short_range_1stnn_uniaxial_quartic(field, parameters):
    r"""
    Returns the short-range interaction of nearest neighbors for a :math:`R^3` field
    defined on an isotropic lattice with periodic boundary conditions.

    The energy is given by:

    .. math::

        \sum_{i, \alpha} K u^3_{i, \alpha} (u_{i+\alpha, \alpha} + u_{i-\alpha, \alpha})

    Parameters
    ----------
    field : jnp.array
        The field to calculate the energy. The last axis holds the 3 components and the
        first three axes are the lattice axes.
    parameters : jax.numpy array
        The parameters of the energy function; ``parameters[0]`` is the coupling K.

    Returns
    -------
    jnp.array
        The energy of the field
    """
    K = parameters[0]
    energy = 0
    for axis in range(3):
        # Component `axis` couples to the same component on the two neighbors located
        # along that same lattice axis (i.e. along the component's own direction).
        component = field[..., axis]
        neighbor_sum = jnp.roll(component, 1, axis=axis) + jnp.roll(component, -1, axis=axis)
        energy += jnp.sum(component**3 * neighbor_sum)
    return energy * K
def short_range_1stnn_trilinear_two_sublattices(field_A, field_B, parameters):
    r"""
    Returns the short-range interaction of nearest neighbors for two :math:`R^3` fields
    defined on interweaving sublattices (A and B sublattices) with periodic boundary
    conditions.

    The energy is given by:

    .. math::

        \sum_{i \sim j, \alpha \neq \beta} D_{ij,\alpha\beta} a_{j, \alpha} b_{i, \alpha} b_{i, \beta}

    Here i is the index of sublattice B and j is the index of sublattice A.
    :math:`i\sim j` means that i and j are nearest neighbors.
    :math:`D_{ij,\alpha\beta}` is the trilinear coupling constant.
    For interweaving cubic lattices, there are 8 such nearest neighbors for a given i.

    Define :math:`a'_{i}` to be the mean-field (averaged over the 8 nearest neighbors of
    a B-site) of field_A on the B site-i. The energy is given by

    .. math::

        \sum_{i, \alpha \neq \beta} 8 D^{\mathrm{nn}}_{\alpha\beta} a'_{i, \alpha} b_{i, \alpha} b_{i, \beta}

    Parameters
    ----------
    field_A : jnp.array
        The field of sublattice A to calculate the energy
    field_B : jnp.array
        The field of sublattice B to calculate the energy
    parameters : jax.numpy array
        The parameters of the energy function; ``parameters[0]`` is the coupling D.

    Returns
    -------
    jnp.array
        The energy of the field
    """
    D = parameters[0]
    # Mean-field of field_A on each B-site (average over its 8 nearest A-neighbors).
    mean_field_A = mean_field_on_interwaving_sublattice(field_A)
    # a'_{i,alpha} * b_{i,alpha}, component by component.
    a_times_b = mean_field_A * field_B
    # Rolling the component axis by +1 and -1 pairs every alpha with both beta != alpha.
    energy = jnp.sum(a_times_b * jnp.roll(field_B, 1, axis=-1))
    energy += jnp.sum(a_times_b * jnp.roll(field_B, -1, axis=-1))
    return energy * D * 8  ## TODO: check if the factor 8 is correct
def short_range_1stnn_biquadratic_iiii_two_sublattices(field_A, field_B, parameters):
    r"""
    Returns the short-range interaction of nearest neighbors for two :math:`R^3` fields
    defined on interweaving sublattices (A and B sublattices) with periodic boundary
    conditions.

    Let the lattice constant be unit length. We assume that the origin of field_B is at
    (0,0,0) and the origin of field_A is at (-0.5,-0.5,-0.5), so (0,0,0) and
    (-0.5,-0.5,-0.5) make a primitive cell. This convention is adopted when implementing
    the sum over i and j. It is not unique; it is just for consistency among the energy
    engines defined for multiferroics.

    Define :math:`a'_{i}` to be the mean-field (averaged over the 8 nearest neighbors of
    a B-site) of field_A on the B site-i. The energy is given by:

    .. math::

        \sum_{i, \alpha \beta \gamma \delta} E_{\alpha\beta\gamma\delta} b_{i, \alpha} b_{i, \beta} a'_{i, \gamma} a'_{i, \delta}

    Parameters
    ----------
    field_A : jnp.array
        The field of sublattice A to calculate the energy
    field_B : jnp.array
        The field of sublattice B to calculate the energy
    parameters : jax.numpy array
        The parameters of the energy function: ``parameters[0:3]`` are
        Exxxx, Exxyy and Exyxy.

    Returns
    -------
    jnp.array
        The energy of the field
    """
    Exxxx = parameters[0]
    Exxyy = parameters[1]
    Exyxy = parameters[2]

    # Mean-field of field_A on each B-site (average over its 8 nearest A-neighbors).
    mean_field_A = mean_field_on_interwaving_sublattice(field_A)
    mean_A_squared = mean_field_A**2
    B_squared = field_B**2

    ## uniaxial contribution
    energy = Exxxx * jnp.sum(B_squared * mean_A_squared)

    ## xxyy, xxzz, yyzz, yyxx, zzxx, zzyy contributions
    energy += Exxyy * jnp.sum(B_squared * jnp.roll(mean_A_squared, 1, axis=-1))
    energy += Exxyy * jnp.sum(B_squared * jnp.roll(mean_A_squared, -1, axis=-1))

    ## xyxy, xzxz, yzyz, yxyx, zxzx, zyzy contributions
    # Must use the mean-field a' (not the raw field_A): field_A lives on sublattice A,
    # so multiplying it site-by-site with field_B would not follow the formula above.
    field_AB = mean_field_A * field_B
    energy += Exyxy * jnp.sum(field_AB * jnp.roll(field_AB, 1, axis=-1))
    energy += Exyxy * jnp.sum(field_AB * jnp.roll(field_AB, -1, axis=-1))
    return energy  ## TODO: check if here a prefactor should be applied
##########################################################################################
## Energies involving spin field and all others
##########################################################################################
def short_range_biquadratic_ijii_two_sublattices(field_A, field_B, parameters):
    r"""
    Returns the short-range interaction of nearest neighbors for two :math:`R^3` fields
    defined on interweaving sublattices (A and B sublattices) with periodic boundary
    conditions.

    Let the lattice constant be unit length. We assume that the origin of field_B is at
    (0,0,0) and the origin of field_A is at (-0.5,-0.5,-0.5), so (0,0,0) and
    (-0.5,-0.5,-0.5) make a primitive cell. This convention is adopted when implementing
    the sum over i and j. It is not unique; it is just for consistency among the energy
    engines defined for multiferroics.

    Define :math:`a'_{i}` to be the mean-field (averaged over the 8 nearest neighbors of
    a B-site) of field_A on the B site-i. The energy is given by:

    .. math::

        \sum_{ij, \alpha \beta \gamma \delta} E_{\alpha\beta\gamma\delta} b_{i, \alpha} b_{j, \beta} a'_{i, \gamma} a'_{i, \delta}

    In the implementation j runs over the 1st, 2nd and 3rd nearest neighbors of i on
    the B sublattice, each shell with its own set of coefficients.

    Parameters
    ----------
    field_A : jnp.array
        The field of sublattice A to calculate the energy
    field_B : jnp.array
        The field of sublattice B to calculate the energy
    parameters : jax.numpy array
        The parameters of the energy function, ordered as
        (E1xxxx, E1xxyy, E1xyxy, E2xxxx, E2xxyy, E2xyxy, E3xxxx, E3xxyy, E3xyxy)
        where the digit is the neighbor shell.

    Returns
    -------
    jnp.array
        The energy of the field
    """
    ## TODO: check if these parameters needed to be divided by corresponding coordination number. Likely not.
    E1xxxx, E1xxyy, E1xyxy = parameters[0], parameters[1], parameters[2]
    E2xxxx, E2xxyy, E2xyxy = parameters[3], parameters[4], parameters[5]
    E3xxxx, E3xxyy, E3xyxy = parameters[6], parameters[7], parameters[8]

    # Mean-field of field_A on each B-site (average over its 8 nearest A-neighbors).
    mean_field_A = mean_field_on_interwaving_sublattice(field_A)
    mean_A_squared = mean_field_A**2

    # Neighbor sums of field_B (mean times coordination number) for the three shells.
    field_B_1stnnsum = mean_field_1stnn_on_single_lattice(field_B) * 6
    field_B_2ndnnsum = mean_field_2ndnn_on_single_lattice(field_B) * 12
    field_B_3rdnnsum = mean_field_3rdnn_on_single_lattice(field_B) * 8

    # Shell-weighted neighbor sums, one per coefficient family.
    weighted_xxxx = E1xxxx * field_B_1stnnsum + E2xxxx * field_B_2ndnnsum + E3xxxx * field_B_3rdnnsum
    weighted_xxyy = E1xxyy * field_B_1stnnsum + E2xxyy * field_B_2ndnnsum + E3xxyy * field_B_3rdnnsum
    weighted_xyxy = E1xyxy * field_B_1stnnsum + E2xyxy * field_B_2ndnnsum + E3xyxy * field_B_3rdnnsum

    ## uniaxial contribution
    energy = jnp.sum(field_B * weighted_xxxx * mean_A_squared)

    ## xxyy, xxzz, yyzz, yyxx, zzxx, zzyy contributions
    B_times_xxyy = field_B * weighted_xxyy
    energy += jnp.sum(B_times_xxyy * jnp.roll(mean_A_squared, 1, axis=-1))
    energy += jnp.sum(B_times_xxyy * jnp.roll(mean_A_squared, -1, axis=-1))

    ## xyxy, xzxz, yzyz, yxyx, zxzx, zyzy contributions
    B_times_A = field_B * mean_field_A
    xyxy_times_A = weighted_xyxy * mean_field_A
    energy += jnp.sum(B_times_A * jnp.roll(xyxy_times_A, 1, axis=-1))
    energy += jnp.sum(B_times_A * jnp.roll(xyxy_times_A, -1, axis=-1))
    return energy
def short_range_biquadratic_ijii(field_A, field_B, parameters):
    r"""
    Returns the short-range interaction of nearest neighbors for two :math:`R^3` fields
    defined on a cubic lattice with periodic boundary conditions.

    The energy is given by:

    .. math::

        \sum_{ij, \alpha \beta \gamma \delta} E_{\alpha\beta\gamma\delta} b_{i, \alpha} b_{j, \beta} a_{i, \gamma} a_{i, \delta}

    j is summed over the 1st, 2nd, 3rd nearest neighbors of i.

    Parameters
    ----------
    field_A : jnp.array
        The field a to calculate the energy
    field_B : jnp.array
        The field b to calculate the energy
    parameters : jax.numpy array
        The parameters of the energy function, ordered as
        (E1xxxx, E1xxyy, E1xyxy, E2xxxx, E2xxyy, E2xyxy, E3xxxx, E3xxyy, E3xyxy)
        where the digit is the neighbor shell.

    Returns
    -------
    jnp.array
        The energy of the field
    """
    ## TODO: check if these parameters needed to be divided by corresponding coordination number. Likely not.
    E1xxxx, E1xxyy, E1xyxy = parameters[0], parameters[1], parameters[2]
    E2xxxx, E2xxyy, E2xyxy = parameters[3], parameters[4], parameters[5]
    E3xxxx, E3xxyy, E3xyxy = parameters[6], parameters[7], parameters[8]

    # Both fields are used site-by-site here; no sublattice mean-field is taken.
    A_squared = field_A**2

    # Neighbor sums of field_B (mean times coordination number) for the three shells.
    field_B_1stnnsum = mean_field_1stnn_on_single_lattice(field_B) * 6
    field_B_2ndnnsum = mean_field_2ndnn_on_single_lattice(field_B) * 12
    field_B_3rdnnsum = mean_field_3rdnn_on_single_lattice(field_B) * 8

    # Shell-weighted neighbor sums, one per coefficient family.
    weighted_xxxx = E1xxxx * field_B_1stnnsum + E2xxxx * field_B_2ndnnsum + E3xxxx * field_B_3rdnnsum
    weighted_xxyy = E1xxyy * field_B_1stnnsum + E2xxyy * field_B_2ndnnsum + E3xxyy * field_B_3rdnnsum
    weighted_xyxy = E1xyxy * field_B_1stnnsum + E2xyxy * field_B_2ndnnsum + E3xyxy * field_B_3rdnnsum

    ## uniaxial contribution
    energy = jnp.sum(field_B * weighted_xxxx * A_squared)

    ## xxyy, xxzz, yyzz, yyxx, zzxx, zzyy contributions
    B_times_xxyy = field_B * weighted_xxyy
    energy += jnp.sum(B_times_xxyy * jnp.roll(A_squared, 1, axis=-1))
    energy += jnp.sum(B_times_xxyy * jnp.roll(A_squared, -1, axis=-1))

    ## xyxy, xzxz, yzyz, yxyx, zxzx, zyzy contributions
    B_times_A = field_B * field_A
    xyxy_times_A = weighted_xyxy * field_A
    energy += jnp.sum(B_times_A * jnp.roll(xyxy_times_A, 1, axis=-1))
    energy += jnp.sum(B_times_A * jnp.roll(xyxy_times_A, -1, axis=-1))
    return energy
def homo_strain_spin_interaction(global_strain, spin_field, parameters):
    r"""
    Returns the homogeneous strain spin interaction energy.

    The energy is given by:

    .. math::

        \sum_{i,j,l,\alpha,\beta} G_{ij, l\alpha\beta} \eta_{l} m_{i, \alpha} m_{j, \beta}

    l is index of strain under Voigt notation.
    When i and j are 1st nearest neighbors, :math:`G_{ij, l\alpha\beta} = G_{1nn, l\alpha\beta}`
    When i and j are 2nd nearest neighbors, :math:`G_{ij, l\alpha\beta} = G_{2nn, l\alpha\beta}`
    When i and j are 3rd nearest neighbors, :math:`G_{ij, l\alpha\beta} = G_{3nn, l\alpha\beta}`

    Parameters
    ----------
    global_strain : jnp.array
        Shape=(6), the global strain of a supercell
    spin_field : jnp.array
        Shape=(nx, ny, nz, 3), the spin field
    parameters : jax.numpy array
        The parameters of the energy function containing, in this order:
        G_1nn_1xx, G_1nn_1yy, G_1nn_4yz : strain-spin constants G1xx (or Gxxxx),
        G1yy (or Gxxyy), G4yz (or Gxyxy) for 1st nearest neighbors;
        G_2nn_1xx, G_2nn_1yy, G_2nn_4yz : the same for 2nd nearest neighbors;
        G_3nn_1xx, G_3nn_1yy, G_3nn_4yz : the same for 3rd nearest neighbors.

    Returns
    -------
    jnp.array
        The homogeneous strain spin interaction energy
    """
    ## in some publication these parameters are named differently: G_nnn_xxxx, G_nnn_xxyy, G_nnn_yzyz
    gs = global_strain
    f = spin_field
    # Neighbor sums of the spin field (mean times coordination number), shell by shell.
    neighbor_sums = (
        mean_field_1stnn_on_single_lattice(f) * 6,
        mean_field_2ndnn_on_single_lattice(f) * 12,
        mean_field_3rdnn_on_single_lattice(f) * 8,
    )
    energy = 0
    for shell, f_nnsum in enumerate(neighbor_sums):
        G_1xx = parameters[3 * shell]
        G_1yy = parameters[3 * shell + 1]
        G_4yz = parameters[3 * shell + 2]
        ## diagonal terms, TODO: check if a factor 0.5 should be added for all diagonal terms
        energy += (G_1xx * gs[0] + G_1yy * (gs[1] + gs[2])) * (f[..., 0] * f_nnsum[..., 0]).sum()
        energy += (G_1xx * gs[1] + G_1yy * (gs[0] + gs[2])) * (f[..., 1] * f_nnsum[..., 1]).sum()
        energy += (G_1xx * gs[2] + G_1yy * (gs[0] + gs[1])) * (f[..., 2] * f_nnsum[..., 2]).sum()
        ## off-diagonal terms: Voigt index 5 pairs (x, y), 4 pairs (x, z), 3 pairs (y, z)
        energy += (G_4yz * gs[5]) * (f[..., 0] * f_nnsum[..., 1]).sum()
        energy += (G_4yz * gs[4]) * (f[..., 0] * f_nnsum[..., 2]).sum()
        energy += (G_4yz * gs[3]) * (f[..., 1] * f_nnsum[..., 2]).sum()
    return energy
def homo_strain_spin_1stnn_interaction(global_strain, spin_field, parameters):
    r"""
    Returns the homogeneous strain spin interaction energy for 1st nearest neighbors.

    The energy is given by:

    .. math::

        \sum_{i,j,l,\alpha,\beta} G_{ij, l\alpha\beta} \eta_{l} m_{i, \alpha} m_{j, \beta}

    l is index of strain under Voigt notation.
    i and j are 1st nearest neighbors, :math:`G_{ij, l\alpha\beta} = G_{l\alpha\beta}`

    Parameters
    ----------
    global_strain : jnp.array
        Shape=(6), the global strain of a supercell
    spin_field : jnp.array
        Shape=(nx, ny, nz, 3), the spin field
    parameters : jax.numpy array
        The parameters of the energy function containing
        G_1xx : Strain-spin interaction constant G1xx (or Gxxxx) for 1st nearest neighbors
        G_1yy : Strain-spin interaction constant G1yy (or Gxxyy) for 1st nearest neighbors
        G_4yz : Strain-spin interaction constant G4yz (or Gxyxy) for 1st nearest neighbors

    Returns
    -------
    jnp.array
        The homogeneous strain spin interaction energy
    """
    ## in some publication these parameters are named differently: G_1nn_xxxx, G_1nn_xxyy, G_1nn_yzyz
    G_1xx, G_1yy, G_4yz = parameters[:3]
    gs = global_strain
    f = spin_field
    # Sum of the spin field over the 6 first neighbors of each site.
    f_1stnnsum = mean_field_1stnn_on_single_lattice(f) * 6

    energy = 0
    ## diagonal terms, TODO: check if a factor 0.5 should be added for these three terms
    for a in range(3):
        b, c = (a + 1) % 3, (a + 2) % 3
        energy += (G_1xx * gs[a] + G_1yy * (gs[b] + gs[c])) * (f[..., a] * f_1stnnsum[..., a]).sum()
    ## off-diagonal terms: Voigt index 5 pairs (x, y), 4 pairs (x, z), 3 pairs (y, z)
    for voigt, a, b in ((5, 0, 1), (4, 0, 2), (3, 1, 2)):
        energy += (G_4yz * gs[voigt]) * (f[..., a] * f_1stnnsum[..., b]).sum()
    return energy
def homo_strain_spin_2ndnn_interaction(global_strain, spin_field, parameters):
    r"""
    Returns the homogeneous strain spin interaction energy for 2nd nearest neighbors.

    The energy is given by:

    .. math::

        \sum_{i,j,l,\alpha,\beta} G_{ij, l\alpha\beta} \eta_{l} m_{i, \alpha} m_{j, \beta}

    l is index of strain under Voigt notation.
    i and j are 2nd nearest neighbors, :math:`G_{ij, l\alpha\beta} = G_{l\alpha\beta}`

    Parameters
    ----------
    global_strain : jnp.array
        Shape=(6), the global strain of a supercell
    spin_field : jnp.array
        Shape=(nx, ny, nz, 3), the spin field
    parameters : jax.numpy array
        The parameters of the energy function containing
        G_1xx : Strain-spin interaction constant G1xx (or Gxxxx) for 2nd nearest neighbors
        G_1yy : Strain-spin interaction constant G1yy (or Gxxyy) for 2nd nearest neighbors
        G_4yz : Strain-spin interaction constant G4yz (or Gxyxy) for 2nd nearest neighbors

    Returns
    -------
    jnp.array
        The homogeneous strain spin interaction energy
    """
    ## in some publication these parameters are named differently: G_2nn_xxxx, G_2nn_xxyy, G_2nn_yzyz
    G_1xx, G_1yy, G_4yz = parameters[:3]
    gs = global_strain
    f = spin_field
    # Sum of the spin field over the 12 second neighbors of each site.
    f_2ndnnsum = mean_field_2ndnn_on_single_lattice(f) * 12

    energy = 0
    ## diagonal terms, TODO: check if a factor 0.5 should be added for these three terms
    for a in range(3):
        b, c = (a + 1) % 3, (a + 2) % 3
        energy += (G_1xx * gs[a] + G_1yy * (gs[b] + gs[c])) * (f[..., a] * f_2ndnnsum[..., a]).sum()
    ## off-diagonal terms: Voigt index 5 pairs (x, y), 4 pairs (x, z), 3 pairs (y, z)
    for voigt, a, b in ((5, 0, 1), (4, 0, 2), (3, 1, 2)):
        energy += (G_4yz * gs[voigt]) * (f[..., a] * f_2ndnnsum[..., b]).sum()
    return energy
def homo_strain_spin_3rdnn_interaction(global_strain, spin_field, parameters):
    r"""
    Returns the homogeneous strain spin interaction energy for 3rd nearest neighbors.

    The energy is given by:

    .. math::

        \sum_{i,j,l,\alpha,\beta} G_{ij, l\alpha\beta} \eta_{l} m_{i, \alpha} m_{j, \beta}

    l is index of strain under Voigt notation.
    i and j are 3rd nearest neighbors, :math:`G_{ij, l\alpha\beta} = G_{l\alpha\beta}`

    Parameters
    ----------
    global_strain : jnp.array
        Shape=(6), the global strain of a supercell
    spin_field : jnp.array
        Shape=(nx, ny, nz, 3), the spin field
    parameters : jax.numpy array
        The parameters of the energy function containing
        G_1xx : Strain-spin interaction constant G1xx (or Gxxxx) for 3rd nearest neighbors
        G_1yy : Strain-spin interaction constant G1yy (or Gxxyy) for 3rd nearest neighbors
        G_4yz : Strain-spin interaction constant G4yz (or Gxyxy) for 3rd nearest neighbors

    Returns
    -------
    jnp.array
        The homogeneous strain spin interaction energy
    """
    ## in some publication these parameters are named differently: G_3nn_xxxx, G_3nn_xxyy, G_3nn_yzyz
    G_1xx, G_1yy, G_4yz = parameters[:3]
    gs = global_strain
    f = spin_field
    # Sum of the spin field over the 8 third neighbors of each site.
    f_3rdnnsum = mean_field_3rdnn_on_single_lattice(f) * 8

    energy = 0
    ## diagonal terms, TODO: check if a factor 0.5 should be added for these three terms
    for a in range(3):
        b, c = (a + 1) % 3, (a + 2) % 3
        energy += (G_1xx * gs[a] + G_1yy * (gs[b] + gs[c])) * (f[..., a] * f_3rdnnsum[..., a]).sum()
    ## off-diagonal terms: Voigt index 5 pairs (x, y), 4 pairs (x, z), 3 pairs (y, z)
    for voigt, a, b in ((5, 0, 1), (4, 0, 2), (3, 1, 2)):
        energy += (G_4yz * gs[voigt]) * (f[..., a] * f_3rdnnsum[..., b]).sum()
    return energy
def get_inhomo_strain_spin_interaction(enable_jit=True):
    """
    Returns the inhomogeneous strain spin interaction function.

    Parameters
    ----------
    enable_jit : bool, optional
        Whether to enable JIT compilation of the local-strain evaluation, by default True

    Returns
    -------
    function
        The interaction function
    """
    # Only the local-strain evaluation is wrapped here; the returned energy function
    # itself is left un-jitted.
    get_local_strain = jit(LocalStrain3D.get_local_strain) if enable_jit else LocalStrain3D.get_local_strain

    def inhomo_strain_spin_interaction(local_displacement, spin_field, parameters):
        r"""
        Returns the inhomogeneous strain spin interaction energy.

        The energy is given by:

        .. math::

            \sum_{i,j,l,\alpha,\beta} G_{ij, l\alpha\beta} \eta_{i,l} m_{i, \alpha} m_{j, \beta}

        l is index of local strain under Voigt notation.
        When i and j are 1st nearest neighbors, :math:`G_{ij, l\alpha\beta} = G_{1nn, l\alpha\beta}`
        When i and j are 2nd nearest neighbors, :math:`G_{ij, l\alpha\beta} = G_{2nn, l\alpha\beta}`
        When i and j are 3rd nearest neighbors, :math:`G_{ij, l\alpha\beta} = G_{3nn, l\alpha\beta}`

        Parameters
        ----------
        local_displacement : jnp.array
            Shape=(nx, ny, nz, 3), the local displacement field
        spin_field : jnp.array
            Shape=(nx, ny, nz, 3), the spin field
        parameters : jax.numpy array
            The parameters of the energy function containing, in this order:
            G_1nn_1xx, G_1nn_1yy, G_1nn_4yz : strain-spin constants G1xx (or Gxxxx),
            G1yy (or Gxxyy), G4yz (or Gxyxy) for 1st nearest neighbors;
            G_2nn_1xx, G_2nn_1yy, G_2nn_4yz : the same for 2nd nearest neighbors;
            G_3nn_1xx, G_3nn_1yy, G_3nn_4yz : the same for 3rd nearest neighbors.

        Returns
        -------
        jnp.array
            The inhomogeneous strain spin interaction energy
        """
        ls = get_local_strain(local_displacement)  # (l1, l2, l3, 6)
        f = spin_field
        # Neighbor sums of the spin field (mean times coordination number), shell by shell.
        neighbor_sums = (
            mean_field_1stnn_on_single_lattice(f) * 6,
            mean_field_2ndnn_on_single_lattice(f) * 12,
            mean_field_3rdnn_on_single_lattice(f) * 8,
        )
        energy = 0
        for shell, f_nnsum in enumerate(neighbor_sums):
            G_1xx = parameters[3 * shell]
            G_1yy = parameters[3 * shell + 1]
            G_4yz = parameters[3 * shell + 2]
            ## diagonal terms, TODO: check if a factor 0.5 should be added for these three terms
            energy += ((G_1xx * ls[..., 0] + G_1yy * (ls[..., 1] + ls[..., 2])) * f[..., 0] * f_nnsum[..., 0]).sum()
            energy += ((G_1xx * ls[..., 1] + G_1yy * (ls[..., 0] + ls[..., 2])) * f[..., 1] * f_nnsum[..., 1]).sum()
            energy += ((G_1xx * ls[..., 2] + G_1yy * (ls[..., 0] + ls[..., 1])) * f[..., 2] * f_nnsum[..., 2]).sum()
            ## off-diagonal terms: Voigt index 5 pairs (x, y), 4 pairs (x, z), 3 pairs (y, z)
            energy += (G_4yz * ls[..., 5] * f[..., 0] * f_nnsum[..., 1]).sum()
            energy += (G_4yz * ls[..., 4] * f[..., 0] * f_nnsum[..., 2]).sum()
            energy += (G_4yz * ls[..., 3] * f[..., 1] * f_nnsum[..., 2]).sum()
        return energy

    return inhomo_strain_spin_interaction
def DM_AFD_1stnn(AFD_field, spin_field, parameters):
    r"""
    Returns the Dzyaloshinskii-Moriya interaction (involving oxygen octahedral tilting
    AFD mode) of nearest neighbors for atomistic spin field on cubic lattice with
    periodic boundary conditions.

    The energy is given by:

    .. math::

        \frac{1}{2}\sum_{i\sim j} L (\omega_i - \omega_j)\cdot (m_i \times m_j)

    Here i and j are nearest neighbors.

    Parameters
    ----------
    AFD_field : jnp.array
        The AFD field (:math:`\omega`) to calculate the energy
    spin_field : jnp.array
        The spin field (:math:`m`) to calculate the energy
    parameters : jax.numpy array
        The parameters of the energy function; ``parameters[0]`` is the coupling L.

    Returns
    -------
    jnp.array
        The energy of the field
    """
    L = parameters[0]
    energy = 0
    # Each site is paired with its neighbor one step back along each lattice axis, so
    # every nearest-neighbor bond is visited exactly once. The summand is symmetric
    # under i <-> j, which is why no explicit 1/2 appears here.
    for axis in range(3):
        shifted_AFD = jnp.roll(AFD_field, 1, axis=axis)
        shifted_spin = jnp.roll(spin_field, 1, axis=axis)
        AFD_diff = AFD_field - shifted_AFD
        spin_cross = jnp.cross(spin_field, shifted_spin)
        energy += jnp.sum(AFD_diff * spin_cross)
    return energy * L