Source code for openferro.engine.ewald

"""
Functions for Ewald summation
"""
# This file is part of OpenFerro.

import numpy as np
import jax
from jax import jit
import jax.numpy as jnp
from openferro.units import Constants

def get_dipole_dipole_ewald(latt, sharding=None):
    """Build the energy function for the dipole-dipole interaction.

    Implemented according to Sec.5.3 of
    "Wang, D., et al. 'Ewald summation for ferroelectric perovskites with
    charges and dipoles.' Computational Materials Science 162 (2019):
    314-321."

    All field-independent quantities (the reciprocal-space kernel summed over
    periodic images of the first Brillouin zone, and the two prefactors) are
    computed once here and captured by the returned closure.

    Parameters
    ----------
    latt : Lattice
        The lattice object. ``latt.size`` is unpacked into three grid
        dimensions and ``latt.latt_vec`` into three lattice vectors, of which
        only the diagonal components (``a1[0]``, ``a2[1]``, ``a3[2]``) are
        read.
    sharding : jax.sharding.Sharding, optional
        If given, the reciprocal-space grid and the kernel are placed with
        ``jax.device_put(..., sharding)``.

    Returns
    -------
    callable
        ``energy_engine(field, parameters)`` returning the dipole-dipole
        interaction energy.
    """
    l1, l2, l3 = latt.size
    vec1, vec2, vec3 = latt.latt_vec
    # Only the diagonal entries of the lattice vectors are used.
    a1 = vec1[0]
    a2 = vec2[1]
    a3 = vec3[2]
    ref_volume = a1 * a2 * a3 * l1 * l2 * l3

    cell_lengths = jnp.array([a1, a2, a3])
    recip = 2 * jnp.pi / cell_lengths
    bmax = jnp.max(recip)
    amin = 2 * np.pi / bmax
    alpha = 5 / amin
    gcut = 2 * np.pi * alpha
    sigma = 1.0 / alpha / jnp.sqrt(2.0)  # the Ewald sigma parameter

    # Prefactors of the reciprocal-space sum and of the real-space (self) sum.
    coef_ksum = 1 / 2.0 / ref_volume / Constants.epsilon0
    coef_rsum = 1 / 2.0 / jnp.pi / Constants.epsilon0 * alpha**3 / 3.0 / jnp.sqrt(jnp.pi)

    # Number of Brillouin-zone repetitions along each axis within the cutoff.
    n1, n2, n3 = (int(gcut / recip[axis]) for axis in range(3))

    # Reciprocal-space grid of the (shifted) first Brillouin zone,
    # shape (l1, l2, l3, 3).
    bz_axes = [
        jnp.arange(0, n_cells) / n_cells * recip[axis]
        for axis, n_cells in enumerate((l1, l2, l3))
    ]
    G_grid_1stBZ = jnp.stack(jnp.meshgrid(*bz_axes, indexing='ij'), axis=-1)

    # Kernel sum_G  G_a G_b exp(-sigma^2 |G|^2 / 2) / |G|^2, accumulated over
    # all zone images.  The symmetric 3x3 matrix is stored as six numbers per
    # grid point, ordered (0,0), (1,1), (2,2), (1,2), (0,2), (0,1).  Unlike
    # the usual Voigt convention the off-diagonal entries are NOT doubled
    # here; the factor of 2 is applied when the kernel is contracted with the
    # field.  This replaces a dense (l1, l2, l3, 3, 3) array to save memory.
    UkGG = jnp.zeros((l1, l2, l3, 6))
    if sharding is not None:
        G_grid_1stBZ = jax.device_put(G_grid_1stBZ, sharding)
        UkGG = jax.device_put(UkGG, sharding)

    zone_images = [
        (m1, m2, m3)
        for m1 in range(-n1, n1)
        for m2 in range(-n2, n2)
        for m3 in range(-n3, n3)
    ]
    for m1, m2, m3 in zone_images:
        offset = jnp.array([m1 * recip[0], m2 * recip[1], m3 * recip[2]])
        G_grid = G_grid_1stBZ + offset.reshape(1, 1, 1, 3)  # (l1, l2, l3, 3)
        G_sq = jnp.sum(G_grid**2, axis=-1)  # (l1, l2, l3)
        Uk_coef = jnp.exp(-0.5 * sigma**2 * G_sq) / G_sq
        if (m1, m2, m3) == (0, 0, 0):
            # Mute the Gamma point, where the expression above divides by 0.
            Uk_coef = Uk_coef.at[0, 0, 0].set(0.0)
        gx = G_grid[..., 0]
        gy = G_grid[..., 1]
        gz = G_grid[..., 2]
        image_term = jnp.stack(
            [
                gx**2 * Uk_coef,
                gy**2 * Uk_coef,
                gz**2 * Uk_coef,
                gy * gz * Uk_coef,  # entry (1,2)
                gx * gz * Uk_coef,  # entry (0,2)
                gx * gy * Uk_coef,  # entry (0,1)
            ],
            axis=-1,
        )
        UkGG = UkGG + image_term

    # Drop the references to the large temporaries.
    G_grid_1stBZ = G_grid = Uk_coef = None

    # (i, j, slot) for the three off-diagonal kernel entries, in the order
    # they are accumulated.
    off_diagonal = ((0, 1, 5), (0, 2, 4), (1, 2, 3))

    def _ewald_ksum(F_fft3):
        """Contract the Fourier-transformed field with the stored kernel.

        Parameters
        ----------
        F_fft3 : ndarray
            Fast Fourier Transform of the field, shape=(l1, l2, l3, 3)

        Returns
        -------
        float
            Ewald summation in k-space (without the ``coef_ksum`` prefactor)
        """
        re_part = F_fft3.real
        im_part = F_fft3.imag
        diagonal = [
            (re_part[..., k]**2 + im_part[..., k]**2) * UkGG[..., k]
            for k in range(3)
        ]
        density = diagonal[0] + diagonal[1] + diagonal[2]
        for i, j, slot in off_diagonal:
            cross = re_part[..., i] * re_part[..., j] + im_part[..., i] * im_part[..., j]
            density = density + 2 * (cross * UkGG[..., slot])
        return density.sum()

    def _ewald_rsum(field):
        """Sum of squared field components over the whole grid.

        Parameters
        ----------
        field : ndarray
            The values of the field, shape=(l1, l2, l3, 3)

        Returns
        -------
        float
            Ewald summation in real space (without the ``coef_rsum`` prefactor)
        """
        return jnp.sum(field**2)

    # Only the two reductions are jitted.
    ewald_ksum_func = jit(_ewald_ksum)
    ewald_rsum_func = jit(_ewald_rsum)

    # The outer energy function is intentionally left un-jitted: per the
    # original author's note, jitting jnp.fft.fftn seems to lead to an error.
    def energy_engine(field, parameters):
        """Calculate the dipole-dipole interaction energy by Ewald summation.

        Parameters
        ----------
        field : ndarray
            The values of the field, shape=(l1, l2, l3, 3)
        parameters : ndarray
            Array of parameters; only ``parameters[0]`` is read and it is
            applied as an overall multiplicative prefactor.

        Returns
        -------
        float
            The dipole-dipole interaction energy
        """
        prefactor = parameters[0]
        # FFT over the three lattice axes only; the component axis is kept.
        F_fft3 = jnp.fft.fftn(field, axes=(0, 1, 2))  # (l1, l2, l3, 3)
        k_space = ewald_ksum_func(F_fft3)
        real_space = ewald_rsum_func(field)
        return (coef_ksum * k_space - coef_rsum * real_space) * prefactor

    return energy_engine
# ## Helper functions # def _ewald_ksum(F_fft3, UkGG): ## warning on too long constant folding because UkGG is in the argument when jitted # """ # Getting ewald summation over the k-space with UkGG in Voigt notation. # Args: # F_fft3: jax.numpy array, shape=(l1, l2, l3, 3). Fast Fourier Transform of the field. # UkGG: jax.numpy array, shape=(l1, l2, l3, 6). UkGG in Voigt notation. # Returns: # jax.numpy array, shape=(1,) # """ # F_real = F_fft3.real # F_imag = F_fft3.imag # ewald_ksum = (F_real[..., 0]**2 + F_imag[..., 0]**2) * UkGG[..., 0] # ewald_ksum += (F_real[..., 1]**2 + F_imag[..., 1]**2) * UkGG[..., 1] # ewald_ksum += (F_real[..., 2]**2 + F_imag[..., 2]**2) * UkGG[..., 2] # ewald_ksum += 2 * ((F_real[..., 0] * F_real[..., 1] + F_imag[..., 0] * F_imag[..., 1]) * UkGG[..., 5]) # ewald_ksum += 2 * ((F_real[..., 0] * F_real[..., 2] + F_imag[..., 0] * F_imag[..., 2]) * UkGG[..., 4]) # ewald_ksum += 2 * ((F_real[..., 1] * F_real[..., 2] + F_imag[..., 1] * F_imag[..., 2]) * UkGG[..., 3]) # return ewald_ksum.sum() # def _ewald_rsum(field): # """ # Getting ewald summation over the real space. # Args: # field: jax.numpy array, shape=(l1, l2, l3, 3). The values of the field. # Returns: # jax.numpy array, shape=(1,) # """ # return jnp.sum(field**2) """ Archived versions of Ewald summation with higher memory usage. For testing purpose only. """ # def get_dipole_dipole_ewald_mid_mem_usage(latt): # """ # Returns the function to calculate the energy of dipole-dipole interaction. # Implemented according to Sec.5.3 of # "Wang, D., et al. "Ewald summation for ferroelectric perovksites with charges and dipoles." Computational Materials Science 162 (2019): 314-321." 
# """ # l1, l2, l3 = latt.size # a1, a2, a3 = latt.latt_vec # a1 = a1[0] # a2 = a2[1] # a3 = a3[2] # ref_volume = a1 * a2 * a3 * l1 * l2 * l3 # a = jnp.array([a1 , a2 , a3 ]) # b = 2 * jnp.pi / a # bmax = jnp.max(b) # amin = 2 * np.pi / bmax # alpha = 5 / amin # gcut = 2 * np.pi * alpha # sigma = 1.0 / alpha / jnp.sqrt(2.0) ## the ewald sigma parameter # ## get coefficients # coef_ksum = 1 / 2.0 / ref_volume / Constants.epsilon0 # coef_rsum = 1 / 2.0 / jnp.pi / Constants.epsilon0 * alpha**3 / 3.0 / jnp.sqrt(jnp.pi) # ## get reriprocal space grid # n1 = int(gcut / b[0]) # n2 = int(gcut / b[1]) # n3 = int(gcut / b[2]) # ng1, ng2, ng3 = l1*n1, l2*n2, l3*n3 # G_grid = jnp.stack( jnp.meshgrid( # jnp.arange(-ng1, ng1) / l1 * b[0], # jnp.arange(-ng2, ng2) / l2 * b[1], # jnp.arange(-ng3, ng3) / l3 * b[2], # indexing='ij'), axis=-1) # (2*ng1, 2*ng2, 2*ng3, 3) # G_grid = jnp.roll(G_grid, shift=(-ng1, -ng2, -ng3), axis=(0, 1, 2)) # move gamma point to (0,0,0) # G_grid = G_grid.reshape(2*n1, l1, 2*n2, l2, 2*n3, l3, 3) # G_grid = G_grid.transpose(1,3,5,0,2,4,6).reshape(l1,l2,l3,-1,3) # (l1, l2, l3, 8*n1*n2*n3, 3) # ## get coefficients for reciprocal space sum # Uk_coef = jnp.exp( - 0.5 * sigma**2 * jnp.sum(G_grid**2, axis=-1) ) / jnp.sum(G_grid**2, axis=-1) # (l1, l2, l3, 8*n1*n2*n3) # Uk_coef = Uk_coef.at[0,0,0,0].set(0.0) # mute Gamma point # ## sum over replica of first Brillouin zone first. 
This reduces the memory usage by a factor of 8*n1*n2*n3/3 # # UkGG = (G_grid[:,:,:,:,None,:] * G_grid[:,:,:,:,:,None] * Uk_coef[:,:,:,:,None,None]).sum(3) # (l1, l2, l3, 3, 3) # UkGG = jnp.zeros((l1, l2, l3, 3, 3)) # for i in range(Uk_coef.shape[-1]): # UkGG += G_grid[:,:,:,i,None,:] * G_grid[:,:,:,i,:,None] * Uk_coef[:,:,:,i,None,None] # G_grid = None # Uk_coef = None # def energy_engine(field, parameters): # Z = parameters['Z_star'] # epsilon_inf = parameters['epsilon_inf'] # ## calculate reciprocal space sum # F_fft3 = jnp.fft.fftn(field, axes=(0,1,2)) # (l1, l2, l3, 3) # ewald_ksum = (F_fft3.real[:,:,:,None,:] * F_fft3.real[:,:,:,:,None] * UkGG).sum() # ewald_ksum += (F_fft3.imag[:,:,:,None,:] * F_fft3.imag[:,:,:,:,None] * UkGG).sum() # ewald_ksum = coef_ksum * ewald_ksum # ## calculate real space sum # ewald_rsum = - coef_rsum * jnp.sum(field**2) # return (ewald_ksum + ewald_rsum) * Z**2 / epsilon_inf # return energy_engine # def get_dipole_dipole_ewald_high_memory_usage(latt): # """ # Returns the function to calculate the energy of dipole-dipole interaction. # Implemented according to Sec.5.3 of # "Wang, D., et al. "Ewald summation for ferroelectric perovksites with charges and dipoles." Computational Materials Science 162 (2019): 314-321." 
# """ # l1, l2, l3 = latt.size # a1, a2, a3 = latt.latt_vec # a1 = a1[0] # a2 = a2[1] # a3 = a3[2] # ref_volume = a1 * a2 * a3 * l1 * l2 * l3 # a = jnp.array([a1 , a2 , a3 ]) # b = 2 * jnp.pi / a # bmax = jnp.max(b) # amin = 2 * np.pi / bmax # alpha = 5 / amin # gcut = 2 * np.pi * alpha # sigma = 1.0 / alpha / jnp.sqrt(2.0) ## the ewald sigma parameter # ## get coefficients # coef_ksum = 1 / 2.0 / ref_volume / Constants.epsilon0 # coef_rsum = 1 / 2.0 / jnp.pi / Constants.epsilon0 * alpha**3 / 3.0 / jnp.sqrt(jnp.pi) # ## get reriprocal space grid # n1 = int(gcut / b[0]) # n2 = int(gcut / b[1]) # n3 = int(gcut / b[2]) # ng1, ng2, ng3 = l1*n1, l2*n2, l3*n3 # G_grid = jnp.stack( jnp.meshgrid( # jnp.arange(-ng1, ng1) / l1 * b[0], # jnp.arange(-ng2, ng2) / l2 * b[1], # jnp.arange(-ng3, ng3) / l3 * b[2], # indexing='ij'), axis=-1) # (2*ng1, 2*ng2, 2*ng3, 3) # G_grid = jnp.roll(G_grid, shift=(-ng1, -ng2, -ng3), axis=(0, 1, 2)) # move gamma point to (0,0,0) # G_grid = G_grid.reshape(2*n1, l1, 2*n2, l2, 2*n3, l3, 3) # G_grid = G_grid.transpose(1,3,5,0,2,4,6).reshape(l1,l2,l3,-1,3) # (l1, l2, l3, 8*n1*n2*n3, 3) # ## get coefficients for reciprocal space sum # Uk_coef = jnp.exp( - 0.5 * sigma**2 * jnp.sum(G_grid**2, axis=-1) ) / jnp.sum(G_grid**2, axis=-1) # (l1, l2, l3, 8*n1*n2*n3) # Uk_coef = Uk_coef.at[0,0,0,0].set(0.0) # mute Gamma point # def energy_engine(field, parameters): # Z = parameters['Z_star'] # epsilon_inf = parameters['epsilon_inf'] # ## calculate reciprocal space sum # F_fft3 = jnp.fft.fftn(field, axes=(0,1,2)) # (l1, l2, l3, 3) # Uk_squared = jnp.sum( F_fft3.real[:,:,:,None,:] * G_grid, axis=-1)**2 # Uk_squared += jnp.sum( F_fft3.imag[:,:,:,None,:] * G_grid, axis=-1)**2 # (l1, l2, l3, 8*n1*n2*n3) # ewald_ksum = coef_ksum * jnp.sum(Uk_coef * Uk_squared) # ## calculate real space sum # ewald_rsum = - coef_rsum * jnp.sum(field**2) # return (ewald_ksum + ewald_rsum) * Z**2 / epsilon_inf # return energy_engine
def dipole_dipole_ewald_plain(field, parameters):
    """Brute-force Ewald summation for dipole-dipole interaction.

    For benchmarking purpose only: the cost grows with the square of the
    number of lattice sites, and every site pair performs a full sum over the
    extended reciprocal-space grid.

    Parameters
    ----------
    field : ndarray
        The field values, shape=(l1, l2, l3, 3)
    parameters : dict
        Dictionary containing:

        a1, a2, a3 : float
            Lattice constants along the three axes. They are used as
            scalars, so the cell is treated as orthogonal.
        Z_star : float
            Born effective charge
        epsilon_inf : float
            High-frequency dielectric constant

    Returns
    -------
    float
        The dipole-dipole interaction energy
    """
    l1, l2, l3 = field.shape[0], field.shape[1], field.shape[2]
    a1 = parameters['a1']
    a2 = parameters['a2']
    a3 = parameters['a3']
    Z = parameters['Z_star']
    epsilon_inf = parameters['epsilon_inf']
    ref_volume = a1 * a2 * a3 * l1 * l2 * l3

    a = jnp.array([a1, a2, a3])
    b = 2 * jnp.pi / a
    bmax = jnp.max(b)
    amin = 2 * np.pi / bmax
    alpha = 5 / amin
    gcut = 2 * np.pi * alpha
    sigma = 1.0 / alpha / jnp.sqrt(2.0)  # the Ewald sigma parameter

    # Prefactors of the reciprocal-space sum and of the real-space (self) sum.
    coef_ksum = 1 / 2.0 / ref_volume / Constants.epsilon0
    coef_rsum = 1 / 2.0 / jnp.pi / Constants.epsilon0 * alpha**3 / 3.0 / jnp.sqrt(jnp.pi)

    # Extended reciprocal-space grid covering all zone images within the cutoff.
    n1 = int(gcut / b[0])
    n2 = int(gcut / b[1])
    n3 = int(gcut / b[2])
    ng1, ng2, ng3 = l1 * n1, l2 * n2, l3 * n3
    G_grid = jnp.stack(
        jnp.meshgrid(
            jnp.arange(-ng1, ng1) / l1 * b[0],
            jnp.arange(-ng2, ng2) / l2 * b[1],
            jnp.arange(-ng3, ng3) / l3 * b[2],
            indexing='ij',
        ),
        axis=-1,
    )  # (2*ng1, 2*ng2, 2*ng3, 3)
    # Move the Gamma point to index (0, 0, 0).
    G_grid = jnp.roll(G_grid, shift=(-ng1, -ng2, -ng3), axis=(0, 1, 2))
    G2_grid = jnp.sum(G_grid**2, axis=-1)

    # The damped Coulomb kernel does not depend on the site pair or on the
    # Cartesian components, so it is built once instead of inside the
    # innermost loop.
    damped_kernel = jnp.exp(-0.5 * sigma**2 * G2_grid) / G2_grid
    damped_kernel = damped_kernel.at[0, 0, 0].set(0.0)  # mute Gamma point

    # Reciprocal-space sum over all ordered site pairs (i, j).
    ewald_ksum = 0.0
    for site_i in np.ndindex(l1, l2, l3):
        field_i = field[site_i]
        for site_j in np.ndindex(l1, l2, l3):
            field_j = field[site_j]
            rij = -jnp.array([p - q for p, q in zip(site_i, site_j)]) * a
            # The phase factor depends on the pair only, not on the components.
            phase = jnp.cos(jnp.sum(G_grid * rij, axis=-1))
            # `mu`/`nu` index Cartesian components; they are deliberately not
            # called `alpha`, which is the Ewald splitting parameter above.
            for mu in range(3):
                for nu in range(3):
                    Q_ij = jnp.sum(
                        damped_kernel * G_grid[..., mu] * G_grid[..., nu] * phase
                    ) * coef_ksum
                    ewald_ksum += Q_ij * field_i[mu] * field_j[nu]

    # Real-space (self) term: a single reduction over all sites and components.
    ewald_rsum = -coef_rsum * jnp.sum(field**2)

    return (ewald_ksum + ewald_rsum) * Z**2 / epsilon_inf