# Source code for abtem.prism.bloch

import warnings

    import cupy as cp
import numpy as np
import pandas as pd

from abtem.core.backend import get_array_module
from abtem.core.constants import kappa
from import energy2sigma, energy2wavelength
from abtem.parametrizations import validate_parametrization

def _F_reflection_conditions(hkl):
    all_even = (hkl % 2 == 0).all(axis=1)
    all_odd = (hkl % 2 == 1).all(axis=1)
    return all_even + all_odd

def _I_reflection_conditions(hkl):
    return (hkl.sum(axis=1) % 2 == 0).all(axis=1)

def _A_reflection_conditions(hkl):
    return (hkl[1:].sum(axis=1) % 2 == 0).all(axis=1)

def _B_reflection_conditions(hkl):
    return (hkl[:, [0, 1]].sum(axis=1) % 2 == 0).all(axis=1)

def _C_reflection_conditions(hkl):
    return (hkl[:-1].sum(axis=1) % 2 == 0).all(axis=1)

class StructureFactors:
    """Electron structure factors of a crystal sampled on a cubic grid of Miller indices.

    Parameters
    ----------
    atoms : ase.Atoms
        The crystal (presumably an ``ase.Atoms``; it must provide ``cell`` with
        ``reciprocal()`` and ``volume``, plus ``numbers`` and
        ``get_scaled_positions()``).
    k_max : float
        Reciprocal-space cutoff, in the units of the norms of
        ``hkl @ atoms.cell.reciprocal()``. Structure factors with
        ``|g| >= k_max`` are set to zero.
    parametrization : str, optional
        Scattering-factor parametrization resolved with
        ``validate_parametrization``; the resolved object must provide
        ``scattering_factor(Z)``.
    """

    def __init__(self, atoms, k_max, parametrization="lobato"):
        self._atoms = atoms
        self._k_max = k_max
        self._hkl = self._make_hkl_grid()  # also sets self._gpts
        self._g_vec = self._hkl @ self._atoms.cell.reciprocal()
        self._array = None
        self._parametrization = validate_parametrization(parametrization)

    def __len__(self):
        return len(self._hkl)

    def _make_hkl_grid(self):
        """Return every integer (h, k, l) with each index in ``[-n, n]``.

        ``n`` is the smallest integer with ``n * min(dg) >= k_max``. The
        method also sets ``self._gpts = 2 * n + 1``. Rows are ordered with h
        slowest and l fastest, so per-row data can be reshaped to
        ``(gpts, gpts, gpts)`` with axes (h, k, l).
        """
        num_tile = int(np.ceil(self._k_max / min(self.dg)))
        self._gpts = num_tile * 2 + 1
        indices = np.arange(-num_tile, num_tile + 1)
        h, k, l = np.meshgrid(indices, indices, indices, indexing="ij")
        return np.stack([h.ravel(), k.ravel(), l.ravel()], axis=1)

    @property
    def dg(self):
        """Norms of the three reciprocal cell vectors."""
        return np.linalg.norm(self.atoms.cell.reciprocal(), axis=1)

    @property
    def hkl(self):
        """Integer Miller indices, one (h, k, l) row per grid point."""
        return self._hkl

    @property
    def g_vec(self):
        """Reciprocal-space vectors ``hkl @ atoms.cell.reciprocal()``."""
        return self._g_vec

    @property
    def g_vec_length(self):
        """Norm of each row of ``g_vec``."""
        return np.linalg.norm(self._g_vec, axis=1)

    @property
    def gpts(self):
        """Number of Miller indices along each axis of the grid."""
        return self._gpts

    @property
    def atoms(self):
        return self._atoms

    @property
    def k_max(self):
        return self._k_max

    def _calculate_scattering_factors(self):
        """Return scattering factors for every (atom, reflection) pair.

        Factors are evaluated once per unique atomic number and unique |g|,
        then expanded back. Entries with |g| > k_max are zero.

        Returns
        -------
        np.ndarray
            Complex array of shape ``(len(atoms.numbers), len(self))``.
        """
        Z_unique, Z_inverse = np.unique(self.atoms.numbers, return_inverse=True)
        g_unique, g_inverse = np.unique(self.g_vec_length, return_inverse=True)

        f_e_unique = np.zeros((Z_unique.size, g_unique.size), dtype=np.complex128)
        for idx, Z in enumerate(Z_unique):
            scattering_factor = self._parametrization.scattering_factor(Z)
            # the parametrization is called with |g|**2
            f_e_unique[idx, :] = scattering_factor(g_unique**2)
        f_e_unique[:, g_unique > self.k_max] = 0.0

        return f_e_unique[np.ix_(Z_inverse, g_inverse)]

    def get_array(self, cache=True):
        """Return the structure factors as a ``(gpts, gpts, gpts)`` array.

        Parameters
        ----------
        cache : bool
            If True, a freshly computed array is stored and returned by later
            calls. A previously cached array is always reused.

        Returns
        -------
        np.ndarray
            Complex array indexed by ``(h, k, l)`` offset by ``-hkl.min()``.
        """
        if self._array is not None:
            return self._array
        array = self._calculate_structure_factor()
        if cache:
            self._array = array
        # return the computed array even when it was not cached
        return array

    def _calculate_structure_factor(self):
        """Compute ``sum_j f_j(g) exp(2 pi i hkl . r_j) / volume`` on the hkl grid.

        ``r_j`` are the scaled (fractional) atomic positions. Entries with
        |g| >= k_max are zeroed and the result is reshaped to
        ``(gpts, gpts, gpts)``.
        """
        positions = self.atoms.get_scaled_positions()
        f_e = self._calculate_scattering_factors()
        # (n_atoms, 3) @ (3, n_hkl) -> phase argument for every (atom, reflection)
        phases = np.exp(2.0j * np.pi * (positions @ self.hkl.T))
        struct_factors = np.sum(f_e * phases, axis=0) / self.atoms.cell.volume
        struct_factors[self.g_vec_length >= self.k_max] = 0.0
        return struct_factors.reshape((self.gpts,) * 3)

    def get_potential(self):
        """Return the real-space potential of one cell from the structure factors.

        The inverse FFT is scaled by ``volume / (kappa * prod(sampling))``,
        with ``sampling = diag(cell) / gpts``. The result is then shifted so
        that its minimum is zero.
        """
        v = np.fft.ifftn(np.fft.ifftshift(self._calculate_structure_factor()))
        sampling = np.diag(self.atoms.cell) / self.gpts
        # NOTE(review): the available copy of this expression was truncated;
        # the denominator is restored as kappa * prod(sampling), the only use
        # of ``sampling`` - confirm against upstream abTEM.
        v = v * self.atoms.cell.volume / (kappa * np.prod(sampling))
        v -= v.min()
        return v
def excitation_errors(g, wavelength):
    """Return the excitation error for each reciprocal-lattice vector.

    Computes ``(2 * g_z - wavelength * |g|**2) / 2``.

    Parameters
    ----------
    g : array_like
        Vectors of shape ``(..., 3)``; the last axis holds (g_x, g_y, g_z).
        A single vector of shape ``(3,)`` is also accepted.
    wavelength : float
        Wavelength, in units reciprocal to those of ``g``.

    Returns
    -------
    np.ndarray
        Excitation errors with shape ``g.shape[:-1]``.
    """
    g = np.asarray(g)
    return (2 * g[..., 2] - wavelength * np.sum(g * g, axis=-1)) / 2
class BlochWaves:
    """Bloch-wave solver built on a ``StructureFactors`` instance.

    Parameters
    ----------
    structure_factors : StructureFactors
        Provides the hkl grid, g vectors and structure-factor array.
    energy : float
        Passed to ``energy2wavelength`` and ``energy2sigma``.
    sg_max : float
        Beams whose excitation error satisfies ``|sg| > sg_max`` are excluded.
    k_max : float, optional
        Beams with ``|g| > k_max`` are excluded. Defaults to half of
        ``structure_factors.k_max``; larger values trigger a warning because
        some couplings g - h then fall outside the structure factors.
    correct : bool
        If True, scale the coupling matrix and eigenvectors by
        ``1 + wavelength * g_z``.
    """

    def __init__(self, structure_factors, energy, sg_max, k_max=None, correct=True):
        self._structure_factors = structure_factors
        self._energy = energy

        if k_max is None:
            k_max = self.structure_factors.k_max / 2
        elif k_max > self.structure_factors.k_max / 2:
            warnings.warn(
                "provided k_max exceed half the k_max of the scattering factors, some couplings are not included"
            )

        sg = excitation_errors(self.structure_factors.g_vec, self.wavelength)
        self._included_hkl = np.where(
            (self.structure_factors.g_vec_length <= k_max)
            & (np.abs(sg) <= sg_max)
            & _F_reflection_conditions(self.structure_factors.hkl)
        )[0]
        self.correct = correct

    @property
    def included_hkl(self):
        """Row indices into ``structure_factors.hkl`` of the beams in the calculation."""
        return self._included_hkl

    @property
    def structure_factors(self):
        return self._structure_factors

    @property
    def energy(self):
        return self._energy

    @property
    def wavelength(self):
        return energy2wavelength(self.energy)

    def excitation_errors(self):
        """Return the excitation errors of the included beams."""
        g = self.structure_factors.g_vec[self.included_hkl]
        return excitation_errors(g, self.wavelength)

    @property
    def size(self):
        """Bytes of an ``n_beams x n_beams`` complex128 matrix (128 bits = 16 bytes each)."""
        return len(self.included_hkl) ** 2 * 128 * 0.125

    def calculate_U_gmh(self):
        """Return the ``(n_beams, n_beams)`` matrix diagonalised by ``get_exit_wave``.

        Off-diagonal element ``[i, j]`` is the structure factor at
        ``hkl[j] - hkl[i]``, scaled by
        ``energy2sigma(energy) / (wavelength * pi * kappa)``. The diagonal is
        replaced by ``2 * sg / wavelength``. If ``self.correct``, off-diagonal
        elements are divided by ``sqrt((1 + wavelength * g_z[i]) *
        (1 + wavelength * g_z[j]))`` and the diagonal by
        ``1 + wavelength * g_z``.
        """
        hkl = self.structure_factors.hkl[self.included_hkl]
        g_vec = self.structure_factors.g_vec[self.included_hkl]

        # gmh[i, j] = hkl[j] - hkl[i]; shift by the grid minimum to get array indices
        gmh = hkl[None, :, :] - hkl[:, None, :]
        index = gmh - self.structure_factors.hkl.min(axis=0)

        struct_factors = self._structure_factors.get_array()
        U_gmh = struct_factors[index[..., 0], index[..., 1], index[..., 2]]

        prefactor = energy2sigma(self.energy) / self.wavelength / np.pi / kappa
        U_gmh = prefactor * U_gmh

        diag = 2 / self.wavelength * self.excitation_errors()

        if self.correct:
            norm = np.sqrt(1 + self.wavelength * g_vec[:, 2])
            U_gmh /= norm[:, None] * norm[None, :]
            diag /= 1 + self.wavelength * g_vec[:, 2]

        np.fill_diagonal(U_gmh, diag)
        return U_gmh

    def _make_plane_wave(self, xp=np):
        """Return a unit vector on the included beams with 1 at the (0, 0, 0) beam.

        Parameters
        ----------
        xp : module
            Array module used to allocate the vector (numpy by default).

        Raises
        ------
        ValueError
            If the (0, 0, 0) beam is not among the included beams.
        """
        hkl = self.structure_factors.hkl[self.included_hkl]
        zero_beam = np.flatnonzero((hkl == 0).all(axis=1))
        if zero_beam.size == 0:
            raise ValueError("the (0, 0, 0) beam is not among the included beams")
        psi_0 = xp.zeros((len(hkl),))
        psi_0[int(zero_beam[0])] = 1.0
        return psi_0

    def get_exit_wave(self, thicknesses, return_complex=False, device="cpu", tol=0.0):
        """Propagate an incident plane wave through each thickness.

        Parameters
        ----------
        thicknesses : iterable of float
            Thicknesses at which the wave is evaluated.
        return_complex : bool
            If True, return the complex beam amplitudes instead of intensities.
        device : str
            Passed to ``get_array_module`` to choose the array backend.
        tol : float
            Beams whose intensity never exceeds ``tol`` are dropped from the
            intensity table.

        Returns
        -------
        array or pandas.DataFrame
            Complex array of shape ``(len(thicknesses), n_beams)`` if
            ``return_complex``. Otherwise a DataFrame with one ``"h k l"``
            column per kept beam and one row per thickness.
        """
        xp = get_array_module(device)

        U_gmh = xp.array(self.calculate_U_gmh())
        v, C = xp.linalg.eigh(U_gmh)
        gamma = v * self.wavelength / 2.0

        if self.correct:
            g_z = self.structure_factors.g_vec[self.included_hkl][:, 2]
            C = C / xp.sqrt(1 + self.wavelength * xp.array(g_z[:, None]))

        C_inv = xp.conjugate(C.T)
        psi_0 = self._make_plane_wave(xp)
        # the expansion coefficients of the incident wave do not depend on thickness
        alpha = C_inv @ psi_0

        psi = xp.stack(
            [C @ (xp.exp(2.0j * xp.pi * thickness * gamma) * alpha) for thickness in thicknesses],
            axis=0,
        )

        if return_complex:
            return psi

        intensities = xp.abs(psi) ** 2
        # cupy needs an explicit device-to-host copy; numpy has no ``asnumpy``
        intensities = getattr(xp, "asnumpy", np.asarray)(intensities)

        return pd.DataFrame(
            {
                f"{h} {k} {l}": intensity
                for (h, k, l), intensity in zip(
                    self.structure_factors.hkl[self.included_hkl], intensities.T
                )
                if np.any(intensity > tol)
            }
        )