Source code for abtem.transfer

"""Module to describe the contrast transfer function (CTF) and the related apertures."""
from __future__ import annotations

import copy
import re
from abc import abstractmethod
from collections import defaultdict
from functools import reduce
from typing import TYPE_CHECKING

import numpy as np

from abtem.core.axes import AxisMetadata, ParameterAxis
from abtem.core.axes import OrdinalAxis
from abtem.core.backend import cp
from abtem.core.backend import get_array_module
from abtem.core.complex import complex_exponential
from abtem.core.energy import (
    Accelerator,
    energy2wavelength,
    HasAcceleratorMixin,
    reciprocal_space_sampling_to_angular_sampling,
)
from abtem.core.fft import fft_crop
from abtem.core.grid import HasGridMixin, polar_spatial_frequencies, Grid
from abtem.core.utils import expand_dims_to_broadcast
from abtem.distributions import (
    BaseDistribution,
    _unpack_distributions,
    _validate_distribution,
)
from abtem.measurements import ReciprocalSpaceLineProfiles
from abtem.transform import ReciprocalSpaceMultiplication

if TYPE_CHECKING:
    from abtem.waves import BaseWaves


[docs]class BaseTransferFunction( ReciprocalSpaceMultiplication, HasAcceleratorMixin, HasGridMixin ): """Base class for transfer functions."""
[docs] def __init__( self, energy: float = None, extent: float | tuple[float, float] = None, gpts: int | tuple[int, int] = None, sampling: float | tuple[float, float] = None, distributions: tuple[str] = (), ): self._accelerator = Accelerator(energy=energy) self._grid = Grid(extent=extent, gpts=gpts, sampling=sampling) super().__init__(distributions=distributions)
@abstractmethod def _evaluate_from_angular_grid(self, alpha, phi): pass @property def angular_sampling(self) -> tuple[float, float]: """The sampling in scattering angles of the transfer function [mrad].""" return reciprocal_space_sampling_to_angular_sampling( self.reciprocal_space_sampling, self.energy ) def _angular_grid(self, device) -> tuple[np.ndarray, np.ndarray]: xp = get_array_module(device) alpha, phi = polar_spatial_frequencies(self.gpts, self.sampling, xp=xp) alpha *= self.wavelength return alpha, phi def _evaluate_kernel(self, waves: BaseWaves = None) -> np.ndarray: """ Evaluate the array to be multiplied with the waves in reciprocal space. Parameters ---------- waves : BaseWaves, optional If given, the array will be evaluated to match the provided waves. Returns ------- kernel : np.ndarray or dask.array.Array """ if waves is None: device = "cpu" else: self.accelerator.match(waves) self.grid.match(waves) device = waves.device self.grid.check_is_defined() self.accelerator.check_is_defined() alpha, phi = self._angular_grid(device) return self._evaluate_from_angular_grid(alpha, phi) def to_diffraction_patterns( self, max_angle: float = None, gpts: int | tuple[int, int] = None, ): """ """ from abtem.measurements import DiffractionPatterns if (self.sampling is None) or (max_angle is not None): if max_angle is None and hasattr(self, "_max_semiangle_cutoff"): max_angle = self._max_semiangle_cutoff elif max_angle is None: raise RuntimeError() sampling = 1 / (max_angle * 1e-3) / 2 * self.wavelength else: sampling = self.sampling if self.gpts is None and gpts is None: gpts = 128 ctf = self.copy() ctf.sampling = sampling ctf.gpts = gpts array = ctf._evaluate_kernel() xp = get_array_module(array) diffraction_patterns = DiffractionPatterns( xp.fft.fftshift(array, axes=(-2, -1)), sampling=ctf.reciprocal_space_sampling, ensemble_axes_metadata=ctf.ensemble_axes_metadata, fftshift=False, metadata={"energy": self.energy}, ) return diffraction_patterns def show(self, **kwargs): return self.to_diffraction_patterns().show(**kwargs)
[docs]class BaseAperture(BaseTransferFunction): """Base class for apertures. Documented in the subclasses."""
[docs] def __init__( self, semiangle_cutoff: float | BaseDistribution = None, energy: float = None, extent: float | tuple[float, float] = None, gpts: int | tuple[int, int] = None, sampling: float | tuple[float, float] = None, distributions: tuple[str] = (), ): self._semiangle_cutoff = semiangle_cutoff super().__init__( energy=energy, extent=extent, gpts=gpts, sampling=sampling, distributions=distributions, )
@property def metadata(self): metadata = {} if not isinstance(self.semiangle_cutoff, BaseDistribution): metadata["semiangle_cutoff"] = self.semiangle_cutoff return metadata @property def ensemble_axes_metadata(self) -> list[AxisMetadata]: return [] @property def _max_semiangle_cutoff(self): if isinstance(self.semiangle_cutoff, BaseDistribution): return max(self.semiangle_cutoff.values) else: return self.semiangle_cutoff @property def nyquist_sampling(self) -> float: """Nyquist sampling corresponding to the semiangle cutoff of the aperture [Å].""" return 1 / (4 * self._max_semiangle_cutoff / self.wavelength * 1e-3) @property def semiangle_cutoff(self): """Semiangle cutoff of the aperture [mrad].""" return self._semiangle_cutoff @semiangle_cutoff.setter def semiangle_cutoff(self, semiangle_cutoff:float | BaseDistribution): self._semiangle_cutoff = semiangle_cutoff def _cropped_aperture(self): if self._max_semiangle_cutoff == np.inf: return self gpts = ( int(2 * np.ceil(self._max_semiangle_cutoff / self.angular_sampling[0])) + 3, int(2 * np.ceil(self._max_semiangle_cutoff / self.angular_sampling[1])) + 3, ) cropped_aperture = self.copy() cropped_aperture.gpts = gpts return cropped_aperture def _evaluate_from_cropped(self, gpts: tuple[int, int]): cropped = self._cropped_aperture() array = cropped._evaluate() array = fft_crop(array, gpts) return array
[docs]def soft_aperture( alpha: np.ndarray, phi: np.ndarray, semiangle_cutoff: float | np.ndarray, angular_sampling: tuple[float, float], ): """ Calculates an array with a disk of ones and a soft edge. Parameters ---------- alpha : 2D array Array of radial angles [mrad]. phi : 2D array Array of azimuthal angles [rad]. semiangle_cutoff : float or 1D array Semiangle cutoff(s) of the aperture(s). If given as an array, a 3D array is returned where the first dimension represents a different aperture for each item in the array of semiangle cutoffs. angular_sampling : tuple of float Reciprocal-space sampling in units of scattering angles [mrad]. Returns ------- soft_aperture_array : 2D or 3D np.ndarray """ xp = get_array_module(alpha) semiangle_cutoff = xp.array(semiangle_cutoff, dtype=xp.float32) base_ndims = len(alpha.shape) semiangle_cutoff, alpha = expand_dims_to_broadcast(semiangle_cutoff, alpha) semiangle_cutoff, phi = expand_dims_to_broadcast( semiangle_cutoff, phi, match_dims=[(-2, -1), (-2, -1)] ) angular_sampling = xp.array(angular_sampling, dtype=xp.float32) * 1e-3 denominator = xp.sqrt( (xp.cos(phi) * angular_sampling[0]) ** 2 + (xp.sin(phi) * angular_sampling[1]) ** 2 ) ndims = len(alpha.shape) zeros = (slice(None),) * (ndims - base_ndims) + (0,) * base_ndims denominator[zeros] = 1.0 array = xp.clip( (semiangle_cutoff - alpha) / denominator + 0.5, a_min=0.0, a_max=1.0 ) array[zeros] = 1.0 return array
[docs]def hard_aperture(alpha: np.ndarray, semiangle_cutoff: float | BaseDistribution): """ Calculates an array with a disk of ones and a soft edge. Parameters ---------- alpha : 2D array Array of radial angles [mrad]. semiangle_cutoff : float or 1D array Semiangle cutoff(s) of the aperture(s). If given as an array, a 3D array is returned where the first dimension represents a different aperture for each item in the array of semiangle cutoffs. Returns ------- hard_aperture_array : 2D or 3D np.ndarray """ xp = get_array_module(alpha) return xp.array(alpha <= semiangle_cutoff).astype(xp.float32)
[docs]class Aperture(BaseAperture): """ A circular aperture cutting off the wave function at a specified angle, employed in both STEM and HRTEM. The abrupt cutoff may be softened by tapering it. Parameters ---------- semiangle_cutoff : float or BaseDistribution The cutoff semiangle of the aperture [mrad]. Alternatively, a distribution of angles may be provided. soft : bool, optional If True, the edge of the aperture is softened (default is True). energy : float, optional Electron energy [eV]. If not provided, inferred from the wave functions. extent : float or two float, optional Lateral extent of wave functions [Å] in `x` and `y` directions. If a single float is given, both are set equal. gpts : two ints, optional Number of grid points describing the wave functions. sampling : two float, optional Lateral sampling of wave functions [1 / Å]. If 'gpts' is also given, will be ignored. """
[docs] def __init__( self, semiangle_cutoff: float | BaseDistribution, soft: bool = True, energy: float = None, extent: float | tuple[float, float] = None, gpts: int | tuple[int, int] = None, sampling: float | tuple[float, float] = None, ): semiangle_cutoff = _validate_distribution(semiangle_cutoff) self._soft = soft super().__init__( distributions=("semiangle_cutoff",), energy=energy, semiangle_cutoff=semiangle_cutoff, extent=extent, gpts=gpts, sampling=sampling, )
@property def ensemble_axes_metadata(self) -> list[AxisMetadata]: ensemble_axes_metadata = [] if isinstance(self.semiangle_cutoff, BaseDistribution): ensemble_axes_metadata += [ ParameterAxis( label="semiangle cutoff", values=self.semiangle_cutoff, units="mrad", _ensemble_mean=self.semiangle_cutoff.ensemble_mean, ) ] return ensemble_axes_metadata @property def soft(self) -> bool: """True if the aperture has a soft edge.""" return self._soft def _evaluate_from_angular_grid( self, alpha: np.ndarray, phi: np.ndarray ) -> np.ndarray: xp = get_array_module(alpha) if self.semiangle_cutoff == xp.inf: return xp.ones_like(alpha) semiangle_cutoff = xp.array(self.semiangle_cutoff) * 1e-3 if ( self.soft and self.grid.check_is_defined(False) and not np.isscalar(alpha) and not np.isscalar(phi) ): aperture = soft_aperture( alpha, phi, semiangle_cutoff, self.angular_sampling ) return aperture else: return hard_aperture(alpha, semiangle_cutoff)
[docs]class Bullseye(BaseAperture): """ Bullseye aperture. Parameters ---------- num_spokes : int Number of spokes. spoke_width : float Width of spokes [deg]. num_rings : int Number of rings. ring_width : float Width of rings [mrad]. semiangle_cutoff : float The cutoff semiangle of the aperture [mrad]. energy : float, optional Electron energy [eV]. If not provided, inferred from the wave functions. extent : float or two float, optional Lateral extent of wave functions [Å] in `x` and `y` directions. If a single float is given, both are set equal. gpts : two ints, optional Number of grid points describing the wave functions. sampling : two float, optional Lateral sampling of wave functions [1 / Å]. If 'gpts' is also given, will be ignored. """
[docs] def __init__( self, num_spokes: int, spoke_width: float, num_rings: int, ring_width: float, semiangle_cutoff: float, energy: float = None, extent: float | tuple[float, float] = None, gpts: int | tuple[int, int] = None, sampling: float | tuple[float, float] = None, ): self._spoke_num = num_spokes self._spoke_width = spoke_width self._num_rings = num_rings self._ring_width = ring_width super().__init__( energy=energy, semiangle_cutoff=semiangle_cutoff, extent=extent, gpts=gpts, sampling=sampling, )
@property def num_spokes(self) -> int: """Number of spokes.""" return self._spoke_num @property def spoke_width(self) -> float: """Width of spokes [deg].""" return self._spoke_width @property def num_rings(self) -> int: """Number of rings.""" return self._num_rings @property def ring_width(self) -> float: """Width of rings [mrad].""" return self._ring_width def _evaluate_from_angular_grid( self, alpha: np.ndarray, phi: np.ndarray ) -> np.ndarray: xp = get_array_module(alpha) alpha = xp.array(alpha) semiangle_cutoff = self.semiangle_cutoff / 1e3 array = alpha < semiangle_cutoff # add crossbars array = array * ( ((phi + np.pi * self.spoke_width / (180 * 2)) * self.num_spokes) % (2 * np.pi) > (np.pi * self.spoke_width / 180 * self.num_spokes) ) # add ring bars end_edges = np.linspace( semiangle_cutoff / self.num_rings, semiangle_cutoff, self.num_rings ) start_edges = end_edges - self.ring_width / 1e3 for start_edge, end_edge in zip(start_edges, end_edges): array[(alpha > start_edge) * (alpha < end_edge)] = 0.0 return array
[docs]class Vortex(BaseAperture): """ Vortex-beam aperture. Parameters ---------- quantum_number : int Quantum number of vortex beam. semiangle_cutoff : float The cutoff semiangle of the aperture [mrad]. energy : float, optional Electron energy [eV]. If not provided, inferred from the wave functions. extent : float or two float, optional Lateral extent of wave functions [Å] in `x` and `y` directions. If a single float is given, both are set equal. gpts : two ints, optional Number of grid points describing the wave functions. sampling : two float, optional Lateral sampling of wave functions [1 / Å]. If 'gpts' is also given, will be ignored. """
[docs] def __init__( self, quantum_number: int, semiangle_cutoff: float, energy: float = None, extent: float | tuple[float, float] = None, gpts: int | tuple[int, int] = None, sampling: float | tuple[float, float] = None, ): self._quantum_number = quantum_number super().__init__( energy=energy, semiangle_cutoff=semiangle_cutoff, extent=extent, gpts=gpts, sampling=sampling, )
@property def quantum_number(self) -> int: """Quantum number of vortex beam.""" return self._quantum_number def _evaluate_from_angular_grid( self, alpha: np.ndarray, phi: np.ndarray ) -> np.ndarray: xp = get_array_module(alpha) alpha = xp.array(alpha) semiangle_cutoff = self.semiangle_cutoff / 1e3 array = alpha < semiangle_cutoff array = array * np.exp(1j * phi * self.quantum_number) return array
[docs]class Zernike(BaseAperture): """ Zernike aperture. Parameters ---------- center_hole_cutoff : float Cutoff semiangle of aperture hole [mrad]. phase_shift: float Phase shift of Zernike film [rad] semiangle_cutoff : float The cutoff semiangle of the aperture [mrad]. energy : float, optional Electron energy [eV]. If not provided, inferred from the wave functions. extent : float or two float, optional Lateral extent of wave functions [Å] in `x` and `y` directions. If a single float is given, both are set equal. gpts : two ints, optional Number of grid points describing the wave functions. sampling : two float, optional Lateral sampling of wave functions [1 / Å]. If 'gpts' is also given, will be ignored. """
[docs] def __init__( self, center_hole_cutoff: float, phase_shift: float, semiangle_cutoff: float, energy: float = None, extent: float | tuple[float, float] = None, gpts: int | tuple[int, int] = None, sampling: float | tuple[float, float] = None, ): self._center_hole_cutoff = center_hole_cutoff self._phase_shift = phase_shift super().__init__( energy=energy, semiangle_cutoff=semiangle_cutoff, extent=extent, gpts=gpts, sampling=sampling, )
@property def center_hole_cutoff(self) -> float: """Cutoff semiangle of aperture hole.""" return self._center_hole_cutoff @property def phase_shift(self) -> float: """Phase shift of Zernike film.""" return self._phase_shift def _evaluate_from_angular_grid( self, alpha: np.ndarray, phi: np.ndarray ) -> np.ndarray: xp = get_array_module(alpha) alpha = xp.array(alpha) semiangle_cutoff = self.semiangle_cutoff / 1e3 center_hole_cutoff = self.center_hole_cutoff / 1e3 phase_shift = self.phase_shift amplitude = xp.asarray(alpha < semiangle_cutoff, dtype=np.float32) phase_array = xp.asarray( xp.logical_and(alpha > center_hole_cutoff, alpha < semiangle_cutoff), dtype=np.float32, ) phase = xp.exp(1.0j * phase_shift * phase_array) array = amplitude * phase return array
[docs]class TemporalEnvelope(BaseTransferFunction): """ Envelope function for simulating partial temporal coherence in the quasi-coherent approximation. Parameters ---------- focal_spread: float or 1D array or BaseDistribution The standard deviation of the focal spread due to chromatic aberration and lens current instability [Å]. Alternatively, a distribution of values may be provided. energy : float, optional Electron energy [eV]. If not provided, inferred from the wave functions. extent : float or two float, optional Lateral extent of wave functions [Å] in `x` and `y` directions. If a single float is given, both are set equal. gpts : two ints, optional Number of grid points describing the wave functions. sampling : two float, optional Lateral sampling of wave functions [1 / Å]. If 'gpts' is also given, will be ignored. """
[docs] def __init__( self, focal_spread: float | BaseDistribution, energy: float = None, extent: float | tuple[float, float] = None, gpts: int | tuple[int, int] = None, sampling: float | tuple[float, float] = None, ): self._accelerator = Accelerator(energy=energy) self._focal_spread = _validate_distribution(focal_spread) super().__init__( distributions=("focal_spread",), energy=energy, extent=extent, gpts=gpts, sampling=sampling, )
@property def focal_spread(self): """The standard deviation of the focal spread [Å].""" return self._focal_spread @focal_spread.setter def focal_spread(self, value): self._focal_spread = _validate_distribution(value) @property def ensemble_axes_metadata(self) -> list[AxisMetadata]: return self._get_axes_metadata_from_distributions( focal_spread={"units": "mrad"} ) def _evaluate_from_angular_grid( self, alpha: np.ndarray, phi: np.ndarray ) -> np.ndarray: xp = get_array_module(alpha) unpacked, _ = _unpack_distributions(self.focal_spread, shape=alpha.shape, xp=xp) (focal_spread,) = unpacked alpha = xp.array(alpha) alpha = xp.expand_dims(alpha, axis=tuple(range(0, self._num_ensemble_axes))) array = xp.exp( -((0.5 * xp.pi / self.wavelength * focal_spread * alpha**2) ** 2) ).astype(xp.float32) return array
def _aberration_property(name, key): def _getter(self): try: return getattr(self, name)[key] except KeyError: return 0.0 def _setter(self, value): value = _validate_distribution(value) getattr(self, name)[key] = value return property(_getter, _setter) class _HasAberrations: _aberration_coefficients: dict energy: float C10: float | BaseDistribution = _aberration_property( "_aberration_coefficients", "C10" ) C12: float | BaseDistribution = _aberration_property( "_aberration_coefficients", "C12" ) phi12: float | BaseDistribution = _aberration_property( "_aberration_coefficients", "phi12" ) C21: float | BaseDistribution = _aberration_property( "_aberration_coefficients", "C21" ) phi21: float | BaseDistribution = _aberration_property( "_aberration_coefficients", "phi21" ) C23: float | BaseDistribution = _aberration_property( "_aberration_coefficients", "C23" ) phi23: float | BaseDistribution = _aberration_property( "_aberration_coefficients", "phi23" ) C30: float | BaseDistribution = _aberration_property( "_aberration_coefficients", "C30" ) C32: float | BaseDistribution = _aberration_property( "_aberration_coefficients", "C32" ) phi32: float | BaseDistribution = _aberration_property( "_aberration_coefficients", "phi32" ) C34: float | BaseDistribution = _aberration_property( "_aberration_coefficients", "C34" ) phi34: float | BaseDistribution = _aberration_property( "_aberration_coefficients", "phi34" ) C41: float | BaseDistribution = _aberration_property( "_aberration_coefficients", "C41" ) phi41: float | BaseDistribution = _aberration_property( "_aberration_coefficients", "phi41" ) C43: float | BaseDistribution = _aberration_property( "_aberration_coefficients", "C43" ) phi43: float | BaseDistribution = _aberration_property( "_aberration_coefficients", "phi43" ) C45: float | BaseDistribution = _aberration_property( "_aberration_coefficients", "C45" ) phi45: float | BaseDistribution = _aberration_property( "_aberration_coefficients", "phi45" ) C50: float | BaseDistribution = _aberration_property( "_aberration_coefficients", "C50" ) C52: float | BaseDistribution = _aberration_property( "_aberration_coefficients", "C52" ) phi52: float | BaseDistribution = _aberration_property( "_aberration_coefficients", "phi52" ) C54: float | BaseDistribution = _aberration_property( "_aberration_coefficients", "C54" ) phi54: float | BaseDistribution = _aberration_property( "_aberration_coefficients", "phi54" ) C56: float | BaseDistribution = _aberration_property( "_aberration_coefficients", "C56" ) phi56: float | BaseDistribution = _aberration_property( "_aberration_coefficients", "phi56" ) astigmatism: float | BaseDistribution = _aberration_property( "_aberration_coefficients", "C12" ) astigmatism_angle: float | BaseDistribution = _aberration_property( "_aberration_coefficients", "phi12" ) coma: float | BaseDistribution = _aberration_property( "_aberration_coefficients", "C21" ) coma_angle: float | BaseDistribution = _aberration_property( "_aberration_coefficients", "phi21" ) Cs: float | BaseDistribution = _aberration_property( "_aberration_coefficients", "C30" ) C5: float | BaseDistribution = _aberration_property( "_aberration_coefficients", "C5" ) @property def defocus(self) -> float | BaseDistribution: """Defocus equivalent to negative C10.""" return -self.C10 @defocus.setter def defocus(self, value: float | BaseDistribution): self.C10 = -value @classmethod def _coefficient_symbols(cls): return tuple(var for var in dir(cls) if re.fullmatch("C[0-9][0-9]", var)) @classmethod def _angular_symbols(cls): return tuple( var for var in dir(_HasAberrations) if re.fullmatch("phi[0-9][0-9]", var) ) def _nonzero_coefficients(self, symbols): for symbol in symbols: if not np.isscalar(self._aberration_coefficients[symbol]): return True if not self._aberration_coefficients[symbol] == 0.0: return True return False @classmethod def _symbols(cls): return cls._coefficient_symbols() + cls._angular_symbols() @staticmethod def _aliases(): return { "defocus": "C10", "Cs": "C30", "C5": "C50", "astigmatism": "C12", "astigmatism_angle": "phi12", "astigmatism3": "C32", "astigmatism3_angle": "phi32", "astigmatism5": "C52", "astigmatism5_angle": "phi52", "coma": "C21", "coma_angle": "phi21", "coma4": "C41", "coma4_angle": "phi41", "trefoil": "C23", "trefoil_angle": "phi23", "trefoil4": "C43", "trefoil4_angle": "phi43", "quadrafoil": "C34", "quadrafoil_angle": "phi34", "quadrafoil5": "C54", "quadrafoil5_angle": "phi54", "pentafoil": "C45", "pentafoil_angle": "phi45", "hexafoil": "C56", "hexafoil_angle": "phi56", } def _default_aberration_coefficients(self): return {symbol: 0.0 for symbol in self._symbols()} @staticmethod def _check_is_valid_aberrations(aberrations): for key in aberrations.keys(): if (key not in polar_symbols) and (key not in polar_aliases.keys()): raise ValueError("{} not a recognized phase aberration".format(key)) @property def _phase_aberrations_ensemble_axes_metadata(self) -> list[AxisMetadata]: axes_metadata = [] for parameter_name, value in self._aberration_coefficients.items(): if isinstance(value, BaseDistribution): m = re.search(r"\d", parameter_name).start() _tex_label = f"${parameter_name[:m]}_{{{parameter_name[m:]}}}$" axes_metadata += [ ParameterAxis( label=parameter_name, values=tuple(value.values), units="Å", _ensemble_mean=value.ensemble_mean, _tex_label=_tex_label, ) ] return axes_metadata @property def aberration_coefficients(self) -> dict[str, float | BaseDistribution]: """The aberration coefficients as a dictionary.""" return copy.deepcopy(self._aberration_coefficients) @property def _has_aberrations(self) -> bool: if np.all( [np.all(value == 0.0) for value in self._aberration_coefficients.values()] ): return False else: return True def set_aberrations( self, aberration_coefficients: dict[str, float | BaseDistribution] ): """ Set the phase of the phase aberration. Parameters ---------- aberration_coefficients : dict Mapping from aberration symbols to their corresponding values. """ for symbol, value in aberration_coefficients.items(): value = _validate_distribution(value) if symbol in self._symbols(): self._aberration_coefficients[symbol] = value elif symbol in self._aliases().keys(): self._aberration_coefficients[self._aliases()[symbol]] = value else: raise ValueError("{} not a recognized parameter".format(symbol)) for symbol, value in aberration_coefficients.items(): if symbol in ("defocus", "C10"): value = _validate_distribution(value) if isinstance(value, str) and value.lower() == "scherzer": if self.energy is None: raise RuntimeError( "energy undefined, Scherzer defocus cannot be evaluated" ) value = scherzer_defocus( self._aberration_coefficients["C30"], self.energy ) if symbol == "defocus": value = -value self._aberration_coefficients["C10"] = value polar_symbols = _HasAberrations._symbols() polar_aliases = _HasAberrations._aliases()
[docs]class SpatialEnvelope(BaseTransferFunction, _HasAberrations): """ Envelope function for simulating partial spatial coherence in the quasi-coherent approximation. Parameters ---------- angular_spread: float or 1D array or BaseDistribution The standard deviation of the angular deviations due to source size [mrad]. Alternatively, a distribution of standard deviations may be provided. aberration_coefficients: dict, optional Mapping from aberration symbols to their corresponding values. All aberration magnitudes should be given in [Å] and angles should be given in [radian]. energy : float, optional Electron energy [eV]. If not provided, inferred from the wave functions. extent : float or two float, optional Lateral extent of wave functions [Å] in `x` and `y` directions. If a single float is given, both are set equal. gpts : two ints, optional Number of grid points describing the wave functions. sampling : two float, optional Lateral sampling of wave functions [1 / Å]. If 'gpts' is also given, will be ignored. kwargs : dict, optional Optionally provide the aberration coefficients as keyword arguments. """
[docs] def __init__( self, angular_spread: float | BaseDistribution, aberration_coefficients: dict = None, energy: float = None, extent: float | tuple[float, float] = None, gpts: int | tuple[int, int] = None, sampling: float | tuple[float, float] = None, **kwargs, ): super().__init__( distributions=polar_symbols + ("angular_spread",), energy=energy, extent=extent, gpts=gpts, sampling=sampling, ) self._angular_spread = _validate_distribution(angular_spread) self._aberration_coefficients = self._default_aberration_coefficients() aberration_coefficients = ( {} if aberration_coefficients is None else aberration_coefficients ) aberration_coefficients = {**aberration_coefficients, **kwargs} self.set_aberrations(aberration_coefficients)
@property def ensemble_axes_metadata(self) -> list[AxisMetadata]: ensemble_axes_metadata = [ *self._phase_aberrations_ensemble_axes_metadata, *self._get_axes_metadata_from_distributions( angular_spread={"units": "mrad"} ), ] return ensemble_axes_metadata @property def angular_spread(self): """The standard deviation of the angular deviations due to source size [mrad].""" return self._angular_spread @angular_spread.setter def angular_spread(self, value): self._angular_spread = value def _evaluate_from_angular_grid( self, alpha: np.ndarray, phi: np.ndarray ) -> np.ndarray: xp = get_array_module(alpha) args = tuple(self.aberration_coefficients.values()) + (self.angular_spread,) unpacked, _ = _unpack_distributions(*args, shape=alpha.shape, xp=xp) angular_spread = unpacked[-1] / 1e3 parameters = dict(zip(polar_symbols, unpacked[:-1])) alpha = xp.array(alpha) alpha = xp.expand_dims(alpha, axis=tuple(range(0, self._num_ensemble_axes))) xp = get_array_module(alpha) dchi_dk = ( 2 * xp.pi / self.wavelength * ( ( parameters["C12"] * xp.cos(2.0 * (phi - parameters["phi12"])) + parameters["C10"] ) * alpha + ( parameters["C23"] * xp.cos(3.0 * (phi - parameters["phi23"])) + parameters["C21"] * xp.cos(1.0 * (phi - parameters["phi21"])) ) * alpha**2 + ( parameters["C34"] * xp.cos(4.0 * (phi - parameters["phi34"])) + parameters["C32"] * xp.cos(2.0 * (phi - parameters["phi32"])) + parameters["C30"] ) * alpha**3 + ( parameters["C45"] * xp.cos(5.0 * (phi - parameters["phi45"])) + parameters["C43"] * xp.cos(3.0 * (phi - parameters["phi43"])) + parameters["C41"] * xp.cos(1.0 * (phi - parameters["phi41"])) ) * alpha**4 + ( parameters["C56"] * xp.cos(6.0 * (phi - parameters["phi56"])) + parameters["C54"] * xp.cos(4.0 * (phi - parameters["phi54"])) + parameters["C52"] * xp.cos(2.0 * (phi - parameters["phi52"])) + parameters["C50"] ) * alpha**5 ) ) dchi_dphi = ( -2 * xp.pi / self.wavelength * ( 1 / 2.0 * (2.0 * parameters["C12"] * xp.sin(2.0 * (phi - parameters["phi12"]))) * alpha + 1 / 3.0 * ( 3.0 * parameters["C23"] * xp.sin(3.0 * (phi - parameters["phi23"])) + 1.0 * parameters["C21"] * xp.sin(1.0 * (phi - parameters["phi21"])) ) * alpha**2 + 1 / 4.0 * ( 4.0 * parameters["C34"] * xp.sin(4.0 * (phi - parameters["phi34"])) + 2.0 * parameters["C32"] * xp.sin(2.0 * (phi - parameters["phi32"])) ) * alpha**3 + 1 / 5.0 * ( 5.0 * parameters["C45"] * xp.sin(5.0 * (phi - parameters["phi45"])) + 3.0 * parameters["C43"] * xp.sin(3.0 * (phi - parameters["phi43"])) + 1.0 * parameters["C41"] * xp.sin(1.0 * (phi - parameters["phi41"])) ) * alpha**4 + (1 / 6.0) * ( 6.0 * parameters["C56"] * xp.sin(6.0 * (phi - parameters["phi56"])) + 4.0 * parameters["C54"] * xp.sin(4.0 * (phi - parameters["phi54"])) + 2.0 * parameters["C52"] * xp.sin(2.0 * (phi - parameters["phi52"])) ) * alpha**5 ) ) array = xp.exp( -xp.sign(angular_spread) * (angular_spread / 2) ** 2 * (dchi_dk**2 + dchi_dphi**2) ) return array
[docs]class Aberrations(BaseTransferFunction, _HasAberrations): """ Phase aberrations. Parameters ---------- aberration_coefficients: dict, optional Mapping from aberration symbols to their corresponding values. All aberration magnitudes should be given in [Å] and angles should be given in [radian]. energy : float, optional Electron energy [eV]. If not provided, inferred from the wave functions. extent : float or two float, optional Lateral extent of wave functions [Å] in `x` and `y` directions. If a single float is given, both are set equal. gpts : two ints, optional Number of grid points describing the wave functions. sampling : two float, optional Lateral sampling of wave functions [1/Å]. If 'gpts' is also given, will be ignored. kwargs : dict, optional Optionally provide the aberration coefficients as keyword arguments. """
[docs] def __init__( self, aberration_coefficients: dict[str, float | BaseDistribution] = None, energy: float = None, extent: float | tuple[float, float] = None, gpts: int | tuple[int, int] = None, sampling: float | tuple[float, float] = None, **kwargs, ): super().__init__( distributions=polar_symbols, energy=energy, extent=extent, gpts=gpts, sampling=sampling, ) aberration_coefficients = ( {} if aberration_coefficients is None else aberration_coefficients ) aberration_coefficients = {**aberration_coefficients, **kwargs} self._aberration_coefficients = self._default_aberration_coefficients() self.set_aberrations(aberration_coefficients)
@property def ensemble_axes_metadata(self): return self._phase_aberrations_ensemble_axes_metadata @property def defocus(self) -> float | BaseDistribution: """The defocus [Å].""" return -self._aberration_coefficients["C10"] @defocus.setter def defocus(self, value: float | BaseDistribution): self.C10 = -value def _evaluate_from_angular_grid( self, alpha: np.ndarray, phi: np.ndarray ) -> np.ndarray: xp = get_array_module(alpha) if not self._has_aberrations: return xp.ones(self.ensemble_shape + alpha.shape, dtype=xp.complex64) parameters, weights = _unpack_distributions( *tuple(self.aberration_coefficients.values()), shape=alpha.shape, xp=xp ) parameters = dict(zip(polar_symbols, parameters)) axis = tuple(range(0, len(self.ensemble_shape))) alpha = xp.expand_dims(alpha, axis=axis) phi = xp.expand_dims(phi, axis=axis).astype(xp.float32) array = xp.zeros(alpha.shape, dtype=np.float32) if self._nonzero_coefficients(("C10", "C12", "phi12")): array = array + ( 1 / 2 * alpha**2 * ( parameters["C10"] + parameters["C12"] * xp.cos(2 * (phi - parameters["phi12"])) ) ) if self._nonzero_coefficients(("C21", "phi21", "C23", "phi23")): array = array + ( 1 / 3 * alpha**3 * ( parameters["C21"] * xp.cos(phi - parameters["phi21"]) + parameters["C23"] * xp.cos(3 * (phi - parameters["phi23"])) ) ) if self._nonzero_coefficients(("C30", "C32", "phi32", "C34", "phi34")): array = array + ( 1 / 4 * alpha**4 * ( parameters["C30"] + parameters["C32"] * xp.cos(2 * (phi - parameters["phi32"])) + parameters["C34"] * xp.cos(4 * (phi - parameters["phi34"])) ) ) if self._nonzero_coefficients(("C41", "phi41", "C43", "phi43", "C45", "phi45")): array = array + ( 1 / 5 * alpha**5 * ( parameters["C41"] * xp.cos((phi - parameters["phi41"])) + parameters["C43"] * xp.cos(3 * (phi - parameters["phi43"])) + parameters["C45"] * xp.cos(5 * (phi - parameters["phi45"])) ) ) if self._nonzero_coefficients( ("C50", "C52", "phi52", "C54", "phi54", "C56", "phi56") ): array = array + ( 1 / 6 * alpha**6 * ( parameters["C50"] + parameters["C52"] * xp.cos(2 * (phi - parameters["phi52"])) + parameters["C54"] * xp.cos(4 * (phi - parameters["phi54"])) + parameters["C56"] * xp.cos(6 * (phi - parameters["phi56"])) ) ) array *= np.float32(2 * xp.pi / self.wavelength) array = complex_exponential(-array) if cp is not None: weights = cp.asnumpy(weights) if weights is not None: array = xp.asarray(weights, dtype=xp.float32) * array return array
[docs]class CTF(_HasAberrations, BaseAperture): """ The contrast transfer function (CTF) describes the aberrations of the objective lens in HRTEM and specifies how the condenser system shapes the probe in STEM. abTEM implements phase aberrations up to 5th order using polar coefficients. See Eq. 2.22 in the reference [1]_. Cartesian coefficients can be converted to polar using the utility function `abtem.transfer.cartesian2polar`. Partial coherence is included as envelopes in the quasi-coherent approximation. See Chapter 3.2 in reference [1]_. Parameters ---------- semiangle_cutoff: float, optional The semiangle cutoff describes the sharp reciprocal-space cutoff due to the objective aperture [mrad] (default is no cutoff). soft : bool, optional If True, the edge of the aperture is softened (default is True). focal_spread: float, optional The standard deviation of the focal spread due to chromatic aberration and lens current instability [Å] (default is 0). angular_spread: float, optional The standard deviation of the angular deviations due to source size [Å] (default is 0). aberration_coefficients: dict, optional Mapping from aberration symbols to their corresponding values. All aberration magnitudes should be given in [Å] and angles should be given in [radian]. energy : float, optional Electron energy [eV]. If not provided, inferred from the wave functions. extent : float or two float, optional Lateral extent of wave functions [Å] in `x` and `y` directions. If a single float is given, both are set equal. gpts : two ints, optional Number of grid points describing the wave functions. sampling : two float, optional Lateral sampling of wave functions [1 / Å]. If 'gpts' is also given, will be ignored. flip_phase : bool, optional Changes the sign of all negative parts of the CTF to positive (following doi:10.1016/j.ultramic.2008.03.004) (default is False). wiener_snr : float, optional Applies a Wiener filter to the CTF (following doi:10.1016/j.ultramic.2008.03.004) with a given SNR value. If no value is given, the default value of 0.0 means that no filter is applied. kwargs : dict, optional Optionally provide the aberration coefficients as keyword arguments. References ---------- .. [1] Kirkland, E. J. (2010). Advanced Computing in Electron Microscopy (2nd ed.). Springer. """
[docs] def __init__( self, semiangle_cutoff: float | BaseDistribution = np.inf, soft: bool = True, focal_spread: float | BaseDistribution = 0.0, angular_spread: float | BaseDistribution = 0.0, aberration_coefficients: dict[str, float | BaseDistribution] = None, energy: float = None, extent: float | tuple[float, float] = None, gpts: int | tuple[int, int] = None, sampling: float | tuple[float, float] = None, flip_phase: bool = False, wiener_snr: float = 0.0, **kwargs, ): super().__init__( distributions=polar_symbols + ( "angular_spread", "focal_spread", "semiangle_cutoff", ), energy=energy, semiangle_cutoff=semiangle_cutoff, extent=extent, gpts=gpts, sampling=sampling, ) self._aberration_coefficients = self._default_aberration_coefficients() aberration_coefficients = ( {} if aberration_coefficients is None else aberration_coefficients ) aberration_coefficients = {**aberration_coefficients, **kwargs} self.set_aberrations(aberration_coefficients) self._angular_spread = _validate_distribution(angular_spread) self._focal_spread = _validate_distribution(focal_spread) self._semiangle_cutoff = _validate_distribution(semiangle_cutoff) self._soft = soft self._flip_phase = flip_phase self._wiener_snr = wiener_snr
@property def aberration_coefficients(self): return self._aberration_coefficients @property def scherzer_defocus(self): """The Scherzer defocus [Å].""" self.accelerator.check_is_defined() if self.Cs == 0.0: raise ValueError() return scherzer_defocus(self.Cs, self.energy) @property def crossover_angle(self): """The first zero-crossing of the phase at Scherzer defocus [mrad].""" return 1e3 * energy2wavelength(self.energy) / self.point_resolution @property def point_resolution(self): """The Scherzer point resolution [Å].""" return point_resolution(self.Cs, self.energy) @property def _aberrations(self): return Aberrations( aberration_coefficients=self.aberration_coefficients, energy=self.energy, extent=self.extent, gpts=self.gpts, ) @property def _aperture(self): return Aperture( semiangle_cutoff=self.semiangle_cutoff, soft=self._soft, energy=self.energy, extent=self.extent, gpts=self.gpts, ) @property def _spatial_envelope(self): return SpatialEnvelope( aberration_coefficients=self.aberration_coefficients, angular_spread=self.angular_spread, energy=self.energy, ) @property def _temporal_envelope(self): return TemporalEnvelope( focal_spread=self.focal_spread, energy=self.energy, ) @property def ensemble_axes_metadata(self): return ( self._spatial_envelope.ensemble_axes_metadata + self._temporal_envelope.ensemble_axes_metadata + self._aperture.ensemble_axes_metadata ) @property def soft(self) -> float: """True if the aperture has a soft edge.""" return self._soft @property def semiangle_cutoff(self) -> float | BaseDistribution: """The semiangle cutoff [mrad].""" return self._semiangle_cutoff @semiangle_cutoff.setter def semiangle_cutoff(self, value: float): self._semiangle_cutoff = value @property def focal_spread(self) -> float | BaseDistribution: """The standard deviation of the focal spread [Å].""" return self._focal_spread @focal_spread.setter def focal_spread(self, value: float): self._focal_spread = value @property def angular_spread(self) -> float | BaseDistribution: """The standard deviation of the angular deviations due to source size [mrad].""" return self._angular_spread @angular_spread.setter def angular_spread(self, value: float): self._angular_spread = value @property def flip_phase(self) -> bool: """If true the signs of all negative parts of the CTF are changed to positive.""" return self._flip_phase @flip_phase.setter def flip_phase(self, value: bool): self._flip_phase = value @property def wiener_snr(self) -> float: """If true a Wiener filter is applied to the CTF.""" return self._wiener_snr @wiener_snr.setter def wiener_snr(self, value: float): self._wiener_snr = value def _evaluate_to_match(self, component, alpha, phi): expanded_axes = () for i, axis_metadata in enumerate(self.ensemble_axes_metadata): expand = all([a != axis_metadata for a in component.ensemble_axes_metadata]) if expand: expanded_axes += (i,) array = component._evaluate_from_angular_grid(alpha, phi) return np.expand_dims(array, expanded_axes) def _evaluate_from_angular_grid(self, alpha, phi, keep_all: bool = False): match_dims = tuple(range(-len(alpha.shape), 0)) array = self._aberrations._evaluate_from_angular_grid(alpha, phi) if self._spatial_envelope.angular_spread != 0.0: new_aberrations_dims = tuple(range(len(self._aberrations.ensemble_shape))) old_match_dims = new_aberrations_dims + match_dims added_dims = int(hasattr(self._spatial_envelope.angular_spread, "values")) new_match_dims = ( tuple(range(len(self._spatial_envelope.ensemble_shape) - added_dims)) + match_dims ) new_array = self._spatial_envelope._evaluate_from_angular_grid(alpha, phi) array, new_array = expand_dims_to_broadcast( array, new_array, match_dims=[old_match_dims, new_match_dims] ) array = array * new_array if self._temporal_envelope.focal_spread != 0.0: new_array = self._temporal_envelope._evaluate_from_angular_grid(alpha, phi) array, new_array = expand_dims_to_broadcast( array, new_array, match_dims=2 * [match_dims] ) array = array * new_array if self._aperture.semiangle_cutoff != np.inf: new_array = self._aperture._evaluate_from_angular_grid(alpha, phi) array, new_array = expand_dims_to_broadcast( array, new_array, match_dims=2 * [match_dims] ) array = array * new_array if self._wiener_snr != 0.0: return ( (1 + 1 / self._wiener_snr) * array**2 / (array**2 + 1 / self._wiener_snr) ) elif self._flip_phase: return array.real - 1j * np.abs(array.imag) else: return array def to_point_spread_functions( self, gpts: int | tuple[int, int], extent: float | tuple[float, float] ): from abtem.waves import Probe return ( Probe(gpts=gpts, extent=extent, energy=self.energy, aperture=self) .build() .complex_images() )
[docs] def profiles(self, gpts: int = 1000, max_angle: float = None, phi: float = 0.0): """ Calculate radial line profiles for each included component (phase aberrations, aperture, temporal and spatial envelopes) of the contrast transfer function. Parameters ---------- gpts: int Number of grid points along the line profiles. max_angle : float The maximum scattering angle included in the radial line profiles [mrad]. The default is 1.5 times the semiangle cutoff or 50 mrad if no semiangle cutoff is set. phi : float The azimuthal angle of the radial line profiles [rad]. Default is 0. Returns ------- ctf_profiles : ReciprocalSpaceLineProfiles Ensemble of reciprocal space line profiles. The first ensemble dimension represents the different """ if max_angle is None: if self.semiangle_cutoff == np.inf: max_angle = 50 else: max_angle = self._max_semiangle_cutoff * 1.6 self.accelerator.check_is_defined() sampling = max_angle / (gpts - 1) / (self.wavelength * 1e3) alpha = np.linspace(0, max_angle * 1e-3, gpts).astype(np.float32) components = dict() components["ctf"] = self._evaluate_to_match(self._aberrations, alpha, phi).imag if self._spatial_envelope.angular_spread != 0.0: components["spatial_envelope"] = self._evaluate_to_match( self._spatial_envelope, alpha, phi ) if self._temporal_envelope.focal_spread != 0.0: components["temporal_envelope"] = self._evaluate_to_match( self._temporal_envelope, alpha, phi ) if self._aperture.semiangle_cutoff != np.inf: components["aperture"] = self._evaluate_to_match(self._aperture, alpha, phi) components["ctf"] = reduce(lambda x, y: x * y, tuple(components.values())) ensemble_axes_metadata = self.ensemble_axes_metadata if len(components) > 1: profiles = np.stack( np.broadcast_arrays(*list(components.values())), axis=-2, ) component_metadata = [ OrdinalAxis( label="", values=tuple(components.keys()), _default_type="overlay" ) ] ensemble_axes_metadata = ensemble_axes_metadata + component_metadata else: profiles = components["ctf"] metadata = {"energy": self.energy} profiles = ReciprocalSpaceLineProfiles( profiles, sampling=sampling, metadata=metadata, ensemble_axes_metadata=ensemble_axes_metadata, ) return profiles
[docs]def nyquist_sampling(semiangle_cutoff: float, energy: float) -> float: """ Calculate the Nyquist sampling. Parameters ---------- semiangle_cutoff: float Semiangle cutoff [mrad]. energy: float Electron energy [eV]. """ wavelength = energy2wavelength(energy) return 1 / (4 * semiangle_cutoff / wavelength * 1e-3)
[docs]def scherzer_defocus(Cs: float, energy: float) -> float: """ Calculate the Scherzer defocus. Parameters ---------- Cs: float Spherical aberration [Å]. energy: float Electron energy [eV]. """ return np.sign(Cs) * np.sqrt(3 / 2 * np.abs(Cs) * energy2wavelength(energy))
[docs]def point_resolution(Cs: float, energy: float) -> float: """ Calculate the Scherzer point resolution. Parameters ---------- Cs: float Spherical aberration [Å]. energy: float Electron energy [eV]. """ return (energy2wavelength(energy) ** 3 * np.abs(Cs) / 6) ** (1 / 4)
[docs]def polar2cartesian(polar): """ Convert between polar and Cartesian aberration coefficients. Parameters ---------- polar : dict Mapping from polar aberration symbols to their corresponding values. Returns ------- cartesian : dict Mapping from Cartesian aberration symbols to their corresponding values. """ polar = defaultdict(lambda: 0, polar) cartesian = dict() cartesian["C10"] = polar["C10"] cartesian["C12a"] = -polar["C12"] * np.cos(2 * polar["phi12"]) cartesian["C12b"] = polar["C12"] * np.sin(2 * polar["phi12"]) cartesian["C21a"] = polar["C21"] * np.sin(polar["phi21"]) cartesian["C21b"] = polar["C21"] * np.cos(polar["phi21"]) cartesian["C23a"] = -polar["C23"] * np.sin(3 * polar["phi23"]) cartesian["C23b"] = polar["C23"] * np.cos(3 * polar["phi23"]) cartesian["C30"] = polar["C30"] cartesian["C32a"] = -polar["C32"] * np.cos(2 * polar["phi32"]) cartesian["C32b"] = polar["C32"] * np.cos(np.pi / 2 - 2 * polar["phi32"]) cartesian["C34a"] = polar["C34"] * np.cos(-4 * polar["phi34"]) k = np.sqrt(3 + np.sqrt(8.0)) cartesian["C34b"] = ( 1 / 4.0 * (1 + k**2) ** 2 / (k**3 - k) * polar["C34"] * np.cos(4 * np.arctan(1 / k) - 4 * polar["phi34"]) ) return cartesian
[docs]def cartesian2polar(cartesian): """ Convert between Cartesian and polar aberration coefficients. Parameters ---------- cartesian : dict Mapping from Cartesian aberration symbols to their corresponding values. Returns ------- polar : dict Mapping from polar aberration symbols to their corresponding values. """ cartesian = defaultdict(lambda: 0, cartesian) polar = dict() polar["C10"] = cartesian["C10"] polar["C12"] = -np.sqrt(cartesian["C12a"] ** 2 + cartesian["C12b"] ** 2) polar["phi12"] = -np.arctan2(cartesian["C12b"], cartesian["C12a"]) / 2.0 polar["C21"] = np.sqrt(cartesian["C21a"] ** 2 + cartesian["C21b"] ** 2) polar["phi21"] = np.arctan2(cartesian["C21a"], cartesian["C21b"]) polar["C23"] = np.sqrt(cartesian["C23a"] ** 2 + cartesian["C23b"] ** 2) polar["phi23"] = -np.arctan2(cartesian["C23a"], cartesian["C23b"]) / 3.0 polar["C30"] = cartesian["C30"] polar["C32"] = -np.sqrt(cartesian["C32a"] ** 2 + cartesian["C32b"] ** 2) polar["phi32"] = -np.arctan2(cartesian["C32b"], cartesian["C32a"]) / 2.0 polar["C34"] = np.sqrt(cartesian["C34a"] ** 2 + cartesian["C34b"] ** 2) polar["phi34"] = np.arctan2(cartesian["C34b"], cartesian["C34a"]) / 4 return polar