Source code for openferro.utilities

"""
Utility functions.
"""
# This file is part of OpenFerro.

import numpy as np
import jax
import jax.numpy as jnp

[docs] def SO3_rotation(B: jnp.ndarray, dt: float): """ Calculate batched rotation matrix for proper rotation around magnetic field axis. Parameters ---------- B : jnp.ndarray The magnetic field with shape (\*, 3) dt : float The time step Returns ------- jnp.ndarray The rotation matrix with shape (\*, 3, 3) Notes ----- Returns the batched rotation matrix (shape=(\*, 3, 3)) of a proper rotation of a vector by an angle theta (shape=(\*,)) around the axes u (shape=(\*, 3)). See https://en.wikipedia.org/wiki/Rotation_matrix for details. The angle and axes are specified by the magnetic field B and time step dt as: u = B / jnp.linalg.norm(B, axis=-1)[..., None] theta = jnp.linalg.norm(B, axis=-1) * dt """ theta = jnp.linalg.norm(B, axis=-1) b = B / theta[..., None] bx = b[..., 0] by = b[..., 1] bz = b[..., 2] theta = theta * dt cos_theta = jnp.cos(theta) sin_theta = jnp.sin(theta) u = 1 - cos_theta ## initialize the rotation matrix rotation_matrix = jnp.stack([ bx*bx*u + cos_theta, bx*by*u-bz*sin_theta, bx*bz*u+by*sin_theta, bx*by*u+bz*sin_theta, by*by*u + cos_theta, by*bz*u-bx*sin_theta, bx*bz*u-by*sin_theta, by*bz*u+bx*sin_theta, bz*bz*u + cos_theta ], axis=-1) rotation_matrix = rotation_matrix.reshape(B.shape[:-1] + (3, 3)) return rotation_matrix