Source code for abtem.detect

"""Module for describing the detection of transmitted waves and different detector types."""
from abc import ABCMeta, abstractmethod
from copy import copy
from typing import Tuple, List, Any, Union, Sequence

import numpy as np

from abtem.base_classes import Cache, Event, watched_property, cached_method
from abtem.device import get_array_module, get_device_function
from abtem.measure import Calibration, calibrations_from_grid, Measurement
from abtem.scan import AbstractScan
from abtem.utils import spatial_frequencies
from abtem.visualize.mpl import show_measurement_2d


def _polar_regions(gpts: Tuple[int, int], angular_sampling: Tuple[float, float], inner: float, outer: float,
                   nbins_radial: int, nbins_azimuthal: int, rotation=0.):
    """
    Create an array of labels for the regions of a given detector geometry.

    Parameters
    ----------
    gpts : two int
        Number of grid points describing the detector regions.
    angular_sampling : two float
        Angular sampling of the discretized dete ctor regions in radians.
    inner : float
        Inner boundary of the detector regions [rad].
    outer : float
        Outer boundary of the detector regions [rad].
    nbins_radial : int
        Number of radial detector bins.
    nbins_azimuthal
        Number of azimuthal detector bins.

    Returns
    -------
    2d array
        Array of integer labels representing the detector regions.
    """
    """Create the polar segmentation of a detector."""

    sampling = (1 / angular_sampling[0] / gpts[0], 1 / angular_sampling[1] / gpts[1])
    kx, ky = spatial_frequencies(gpts, sampling)

    alpha_x = np.asarray(kx)
    alpha_y = np.asarray(ky)

    alpha = np.sqrt(alpha_x.reshape((-1, 1)) ** 2 + alpha_y.reshape((1, -1)) ** 2)

    radial_bins = -np.ones(gpts, dtype=int)
    valid = (alpha >= inner) & (alpha <= outer)

    radial_bins[valid] = nbins_radial * (alpha[valid] - inner) / (outer - inner)

    angles = (np.arctan2(alpha_x[:, None], alpha_y[None]) + rotation) % (2 * np.pi)

    angular_bins = np.floor(nbins_azimuthal * (angles / (2 * np.pi)))
    angular_bins = np.clip(angular_bins, 0, nbins_azimuthal - 1).astype(np.int)

    bins = -np.ones(gpts, dtype=int)
    bins[valid] = angular_bins[valid] + radial_bins[valid] * nbins_azimuthal
    return bins


def check_max_angle_exceeded(waves, max_angle):
    if (max_angle is not None) and (not isinstance(max_angle, str)):
        if max_angle > min(waves.cutoff_scattering_angles):
            raise RuntimeError('Detector max angle exceeds the cutoff scattering angle.')


[docs]class AbstractDetector(metaclass=ABCMeta): """Abstract base class for all detectors.""" def __init__(self, max_detected_angle=None, save_file: str = None): if save_file is not None: save_file = str(save_file) if not save_file.endswith('.hdf5'): self._save_file = save_file + '.hdf5' else: self._save_file = save_file else: self._save_file = None self._max_detected_angle = max_detected_angle @property def save_file(self) -> str: """The path to the file for saving the detector output.""" return self._save_file @abstractmethod def detect(self, waves) -> Any: pass @abstractmethod def allocate_measurement(self, waves, scan) -> Measurement: pass
class _PolarDetector(AbstractDetector): """Class to define a polar detector, forming the basis of annular and segmented detectors.""" def __init__(self, inner: float = None, outer: float = None, radial_steps: float = 1., azimuthal_steps: float = None, offset: Tuple[float, float] = None, rotation: float = 0., save_file: str = None): self._inner = inner self._outer = outer self._radial_steps = radial_steps if azimuthal_steps is None: azimuthal_steps = 2 * np.pi self._azimuthal_steps = azimuthal_steps self._rotation = rotation self._offset = offset self.cache = Cache(1) self.changed = Event() super().__init__(max_detected_angle=outer, save_file=save_file) @classmethod def _label_to_index(cls, labels): xp = get_array_module(labels) labels = labels.flatten() labels_order = labels.argsort() sorted_labels = labels[labels_order] indices = xp.arange(0, len(labels) + 1)[labels_order] index = xp.arange(0, np.max(labels) + 1) lo = xp.searchsorted(sorted_labels, index, side='left') hi = xp.searchsorted(sorted_labels, index, side='right') for i, (l, h) in enumerate(zip(lo, hi)): yield indices[l:h] def _get_bins(self, cutoff_scattering_angle=None): if self._inner is None: inner = 0. else: inner = self._inner if self._outer is None: if cutoff_scattering_angle is None: raise RuntimeError('The outer integration angle is not set.') outer = cutoff_scattering_angle outer = np.floor(outer / self._radial_steps) * self._radial_steps else: outer = self._outer nbins_radial = int(np.ceil((outer - inner) / self._radial_steps)) nbins_azimuthal = int(np.ceil(2 * np.pi / self._azimuthal_steps)) return nbins_radial, nbins_azimuthal, inner, outer @cached_method('cache') def _get_regions(self, gpts: Tuple[int, int], angular_sampling: Tuple[float, float], cutoff_scattering_angle: float = None, xp=np) -> List[np.ndarray]: nbins_radial, nbins_azimuthal, inner, outer = self._get_bins(cutoff_scattering_angle) region_labels = _polar_regions(gpts, (angular_sampling[0] / 1e3, angular_sampling[1] / 1e3), inner / 1e3, outer / 1e3, nbins_radial, nbins_azimuthal, rotation=self._rotation) if self._offset is not None: offset = (int(round(self._offset[0] / angular_sampling[0])), int(round(self._offset[1] / angular_sampling[1]))) if (abs(offset[0]) > region_labels.shape[0]) or (abs(offset[1]) > region_labels.shape[1]): raise RuntimeError('Detector offset exceeds maximum detected angle.') region_labels = np.roll(region_labels, offset, (0, 1)) region_labels = xp.asarray(region_labels) if np.all(region_labels == -1): raise RuntimeError('Zero-sized detector region.') region_indices = [] for indices in self._label_to_index(region_labels): region_indices.append(indices) return region_indices def allocate_measurement(self, waves, scan: AbstractScan = None) -> Measurement: """ Allocate a Measurement object or an hdf5 file. Parameters ---------- waves : Waves object An example of the scan : Scan object The scan object that will define the scan dimensions the measurement. Returns ------- Measurement object or str The allocated measurement or path to hdf5 file with the measurement data. """ waves.grid.check_is_defined() waves.accelerator.check_is_defined() if scan is None: shape = () calibrations = () else: shape = scan.shape calibrations = scan.calibrations nbins_radial, nbins_azimuthal, inner, _ = self._get_bins(min(waves.cutoff_scattering_angles)) if nbins_radial > 1: shape += (nbins_radial,) calibrations += (Calibration(offset=inner, sampling=self._radial_steps, units='mrad'),) if nbins_azimuthal > 1: shape += (nbins_azimuthal,) calibrations += (Calibration(offset=0, sampling=self._azimuthal_steps, units='rad'),) array = np.zeros(shape, dtype=np.float32) measurement = Measurement(array, calibrations=calibrations) if isinstance(self.save_file, str): measurement = measurement.write(self.save_file) return measurement def show(self, waves, **kwargs): """ Visualize the detector region(s) of the detector as applied to a specified wave function. Parameters ---------- waves : Waves or SMatrix object The wave function the visualization will be created to match kwargs : Additional keyword arguments for abtem.visualize.mpl.show_measurement_2d. """ waves.grid.check_is_defined() array = np.full(waves.gpts, -1, dtype=np.int) for i, indices in enumerate(self._get_regions(waves.gpts, waves.angular_sampling, min(waves.cutoff_scattering_angles))): array.ravel()[indices] = i calibrations = calibrations_from_grid(waves.gpts, waves.sampling, names=['alpha_x', 'alpha_y'], units='mrad', scale_factor=waves.wavelength * 1e3, fourier_space=True) array = np.fft.fftshift(array, axes=(-1, -2)) measurement = Measurement(array, calibrations=calibrations, name='Detector regions') return show_measurement_2d(measurement, discrete_cmap=True, **kwargs)
[docs]class AnnularDetector(_PolarDetector): """ Annular detector object. The annular detector integrates the intensity of the detected wave functions between an inner and outer integration limit. Parameters ---------- inner: float Inner integration limit [mrad]. outer: float Outer integration limit [mrad]. offset: two float, optional Center offset of integration region [mrad]. save_file: str, optional The path to the file for saving the detector output. """ def __init__(self, inner: float, outer: float, offset: Tuple[float, float] = None, save_file: str = None): super().__init__(inner=inner, outer=outer, offset=offset, radial_steps=outer - inner, save_file=save_file) @property def inner(self) -> float: """Inner integration limit [mrad].""" return self._inner @inner.setter @watched_property('changed') def inner(self, value: float): self._inner = value @property def outer(self) -> float: """Outer integration limit [mrad].""" return self._outer @outer.setter @watched_property('changed') def outer(self, value: float): self._max_detected_angle = value self._outer = value def _integrate_array(self, array: np.ndarray, angular_sampling: Tuple[float, float], cutoff_scattering_angle: float = None): xp = get_array_module(array) indices = self._get_regions(array.shape[-2:], angular_sampling, cutoff_scattering_angle)[0] indexed = array.reshape(array.shape[:-2] + (-1,))[..., indices] values = xp.sum(indexed, axis=-1) return values
[docs] def integrate(self, diffraction_patterns: Measurement) -> Measurement: """ Integrate diffraction pattern measurements on the detector region. Parameters ---------- diffraction_patterns : 2d, 3d or 4d Measurement object The collection diffraction patterns to be integrated. Returns ------- Measurement """ if diffraction_patterns.dimensions < 2: raise ValueError() if not (diffraction_patterns.calibrations[-1].units == diffraction_patterns.calibrations[-2].units): raise ValueError() sampling = (diffraction_patterns.calibrations[-2].sampling, diffraction_patterns.calibrations[-1].sampling) calibrations = diffraction_patterns.calibrations[:-2] array = np.fft.ifftshift(diffraction_patterns.array, axes=(-2, -1)) cutoff_scattering_angle = min(diffraction_patterns.calibrations[-2].sampling * (diffraction_patterns.array.shape[-2] // 2), diffraction_patterns.calibrations[-1].sampling * (diffraction_patterns.array.shape[-1] // 2), ) if cutoff_scattering_angle < self.outer: raise RuntimeError('Outer integration limit exceeds the maximum measurement scattering angle ' f'({cutoff_scattering_angle} mrad)') return Measurement(self._integrate_array(array, sampling, cutoff_scattering_angle), calibrations=calibrations)
[docs] def detect(self, waves) -> np.ndarray: """ Integrate the intensity of a the wave functions over the detector range. Parameters ---------- waves : Waves object The batch of wave functions to detect. Returns ------- 1d array Detected values as a 1D array. The array has the same length as the batch size of the wave functions. """ xp = get_array_module(waves.array) fft2 = get_device_function(xp, 'fft2') abs2 = get_device_function(xp, 'abs2') intensity = abs2(fft2(waves.array, overwrite_x=False)) return self._integrate_array(intensity, waves.angular_sampling, min(waves.cutoff_scattering_angles))
def __copy__(self) -> 'AnnularDetector': return self.__class__(self.inner, self.outer, save_file=self.save_file)
[docs] def copy(self) -> 'AnnularDetector': """Make a copy.""" return copy(self)
[docs]class FlexibleAnnularDetector(_PolarDetector): """ Flexible annular detector object. The FlexibleAnnularDetector object allows choosing the integration limits after running the simulation by radially binning the intensity. Parameters ---------- step_size: float The radial separation between integration regions [mrad]. save_file: str The path to the file used for saving the detector output. """ def __init__(self, step_size: float = 1., save_file: str = None): super().__init__(radial_steps=step_size, save_file=save_file) @property def step_size(self) -> float: """ Step size [mrad]. """ return self._radial_steps @step_size.setter @watched_property('changed') def step_size(self, value: float): self._radial_steps = value
[docs] def detect(self, waves) -> np.ndarray: """ Integrate the intensity of a the wave functions over the detector range. Parameters ---------- waves: Waves object The batch of wave functions to detect. Returns ------- 2d array Detected values. The array has shape of (batch size, number of bins). """ xp = get_array_module(waves.array) fft2 = get_device_function(xp, 'fft2') abs2 = get_device_function(xp, 'abs2') sum_run_length_encoded = get_device_function(xp, 'sum_run_length_encoded') intensity = abs2(fft2(waves.array, overwrite_x=False)) indices = self._get_regions(waves.gpts, waves.angular_sampling, min(waves.cutoff_scattering_angles), xp) separators = xp.concatenate((xp.array([0]), xp.cumsum(xp.array([len(ring) for ring in indices])))) intensity = intensity.reshape((intensity.shape[0], -1))[:, xp.concatenate(indices)] result = xp.zeros((len(intensity), len(indices)), dtype=xp.float32) sum_run_length_encoded(intensity, result, separators) return result
def __copy__(self) -> 'FlexibleAnnularDetector': return self.__class__(self.step_size, save_file=self.save_file)
[docs] def copy(self) -> 'FlexibleAnnularDetector': """ Make a copy. """ return copy(self)
[docs]class SegmentedDetector(_PolarDetector): """ Segmented detector object. The segmented detector covers an annular angular range, and is partitioned into several integration regions divided to radial and angular segments. This can be used for simulating differential phase contrast (DPC) imaging. Parameters ---------- inner: float Inner integration limit [mrad]. outer: float Outer integration limit [mrad]. nbins_radial: int Number of radial bins. nbins_angular: int Number of angular bins. save_file: str The path to the file used for saving the detector output. """ def __init__(self, inner: float, outer: float, nbins_radial: int, nbins_angular: int, rotation: float = 0., save_file: str = None): radial_steps = (outer - inner) / nbins_radial azimuthal_steps = 2 * np.pi / nbins_angular super().__init__(inner=inner, outer=outer, radial_steps=radial_steps, azimuthal_steps=azimuthal_steps, rotation=rotation, save_file=save_file) @property def inner(self) -> float: """Inner integration limit [mrad].""" return self._inner @inner.setter @watched_property('changed') def inner(self, value: float): self._inner = value @property def outer(self) -> float: """Outer integration limit [mrad].""" return self._outer @outer.setter @watched_property('changed') def outer(self, value: float): self._outer = value @property def nbins_radial(self) -> int: """Number of radial bins.""" return int((self.outer - self.inner) / self._radial_steps) @nbins_radial.setter @watched_property('changed') def nbins_radial(self, value: int): self._radial_steps = (self.outer - self.inner) / value @property def nbins_angular(self) -> int: """Number of angular bins.""" return int(2 * np.pi / self._azimuthal_steps) @nbins_angular.setter @watched_property('changed') def nbins_angular(self, value: float): self._azimuthal_steps = 2 * np.pi / value
[docs] def detect(self, waves) -> np.ndarray: """ Integrate the intensity of a the wave functions over the detector range. Parameters ---------- waves: Waves object The batch of wave functions to detect. Returns ------- 3d array Detected values. The first dimension indexes the batch size, the second and third indexes the radial and angular bins, respectively. """ xp = get_array_module(waves.array) fft2 = get_device_function(xp, 'fft2') abs2 = get_device_function(xp, 'abs2') sum_run_length_encoded = get_device_function(xp, 'sum_run_length_encoded') intensity = abs2(fft2(waves.array, overwrite_x=False)) indices = self._get_regions(waves.gpts, waves.angular_sampling, min(waves.cutoff_scattering_angles), xp) separators = xp.concatenate((xp.array([0]), xp.cumsum(xp.array([len(ring) for ring in indices])))) intensity = intensity.reshape((intensity.shape[0], -1))[:, xp.concatenate(indices)] result = xp.zeros((len(intensity), len(separators) - 1), dtype=xp.float32) sum_run_length_encoded(intensity, result, separators) shape = (-1,) if self.nbins_radial > 1: shape += (self.nbins_radial,) if self.nbins_angular > 1: shape += (self.nbins_angular,) return result.reshape(shape)
def __copy__(self) -> 'SegmentedDetector': return self.__class__(inner=self.inner, outer=self.outer, nbins_radial=self.nbins_radial, nbins_angular=self.nbins_angular, save_file=self.save_file)
[docs] def copy(self) -> 'SegmentedDetector': """Make a copy.""" return copy(self)
[docs]class PixelatedDetector(AbstractDetector): """ Pixelated detector object. The pixelated detector records the intensity of the Fourier-transformed exit wavefunction. This may be used for example for simulating 4D-STEM. Parameters ---------- max_angle : str or float or None The diffraction patterns will be detected up to this angle. If set to a string it must be 'limit' or 'valid' resample : 'uniform' or False If 'uniform', the diffraction patterns from rectangular cells will be downsampled to a uniform angular sampling. mode : 'intensity' or 'complex' save_file : str The path to the file used for saving the detector output. """ def __init__(self, max_angle: Union[str, float] = 'valid', resample: Union[str, float] = False, mode='intensity', save_file: str = None): self._max_angle = max_angle self._resample = resample self._mode = mode super().__init__(save_file=save_file) @property def max_angle(self): return self._max_angle @property def resample(self): return self._resample def _bilinear_nodes_and_weight(self, old_shape, new_shape, old_angular_sampling, new_angular_sampling, xp): nodes = [] weights = [] old_sampling = (1 / old_angular_sampling[0] / old_shape[0], 1 / old_angular_sampling[1] / old_shape[1]) new_sampling = (1 / new_angular_sampling[0] / new_shape[0], 1 / new_angular_sampling[1] / new_shape[1]) for n, m, r, d in zip(old_shape, new_shape, old_sampling, new_sampling): k = xp.fft.fftshift(xp.fft.fftfreq(n, r).astype(xp.float32)) k_new = xp.fft.fftshift(xp.fft.fftfreq(m, d).astype(xp.float32)) distances = k_new[None] - k[:, None] distances[distances < 0.] = np.inf w = distances.min(0) / (k[1] - k[0]) w[w == np.inf] = 0. nodes.append(distances.argmin(0)) weights.append(w) v, u = nodes vw, uw = weights v, u, vw, uw = xp.broadcast_arrays(v[:, None], u[None, :], vw[:, None], uw[None, :]) return v, u, vw, uw def _resampled_gpts(self, gpts, angular_sampling): if self._resample is False: return gpts, angular_sampling if self._resample == 'uniform': scale_factor = (angular_sampling[0] / max(angular_sampling), angular_sampling[1] / max(angular_sampling)) else: scale_factor = (angular_sampling[0] / self._resample[0], angular_sampling[1] / self._resample[1]) new_gpts = (int(np.ceil(gpts[0] * scale_factor[0])), int(np.ceil(gpts[1] * scale_factor[1]))) if np.abs(new_gpts[0] - new_gpts[1]) <= 2: new_gpts = (min(new_gpts),) * 2 new_angular_sampling = (angular_sampling[0] / scale_factor[0], angular_sampling[1] / scale_factor[1]) return new_gpts, new_angular_sampling def _interpolate(self, array, angular_sampling): xp = get_array_module(array) interpolate_bilinear = get_device_function(xp, 'interpolate_bilinear') new_gpts, new_angular_sampling = self._resampled_gpts(array.shape[-2:], angular_sampling) v, u, vw, uw = self._bilinear_nodes_and_weight(array.shape[-2:], new_gpts, angular_sampling, new_angular_sampling, xp) return interpolate_bilinear(array, v, u, vw, uw)
[docs] def allocate_measurement(self, waves, scan: AbstractScan = None) -> Measurement: """ Allocate a Measurement object or an hdf5 file. Parameters ---------- waves : Waves or SMatrix object The wave function that will define the shape of the diffraction patterns. scan: Scan object The scan object that will define the scan dimensions the measurement. Returns ------- Measurement object or str The allocated measurement or path to hdf5 file with the measurement data. """ waves.grid.check_is_defined() waves.accelerator.check_is_defined() check_max_angle_exceeded(waves, self.max_angle) gpts = waves.downsampled_gpts(self.max_angle) gpts, new_angular_sampling = self._resampled_gpts(gpts, angular_sampling=waves.angular_sampling) sampling = (1 / new_angular_sampling[0] / gpts[0] * waves.wavelength * 1000, 1 / new_angular_sampling[1] / gpts[1] * waves.wavelength * 1000) calibrations = calibrations_from_grid(gpts, sampling, names=['alpha_x', 'alpha_y'], units='mrad', scale_factor=waves.wavelength * 1000, fourier_space=True) if scan is None: scan_shape = () scan_calibrations = () elif isinstance(scan, tuple): scan_shape = scan scan_calibrations = (None,) * len(scan) else: scan_shape = scan.shape scan_calibrations = scan.calibrations if self._mode == 'intensity': array = np.zeros(scan_shape + gpts, dtype=np.float32) elif self._mode == 'complex': array = np.zeros(scan_shape + gpts, dtype=np.complex64) else: raise ValueError() measurement = Measurement(array, calibrations=scan_calibrations + calibrations) if isinstance(self.save_file, str): measurement = measurement.write(self.save_file) return measurement
[docs] def detect(self, waves) -> np.ndarray: """ Calculate the far field intensity of the wave functions. The output is cropped to include the non-suppressed frequencies from the antialiased 2D fourier spectrum. Parameters ---------- waves: Waves object The batch of wave functions to detect. Returns ------- Detected values. The first dimension indexes the batch size, the second and third indexes the two components of the spatial frequency. """ xp = get_array_module(waves.array) abs2 = get_device_function(xp, 'abs2') waves = waves.far_field(max_angle=self.max_angle) if self._mode == 'intensity': array = abs2(waves.array) elif self._mode == 'complex': array = waves.array else: raise ValueError() array = xp.fft.fftshift(array, axes=(-2, -1)) if self._resample: array = self._interpolate(array, waves.angular_sampling) return array
[docs]class WavefunctionDetector(AbstractDetector): """ Wave function detector object The wave function detector records the raw exit wave functions. Parameters ---------- save_file: str The path to the file used for saving the detector output. """ def __init__(self, save_file: str = None): super().__init__(max_detected_angle=None, save_file=save_file)
[docs] def allocate_measurement(self, waves, scan: AbstractScan) -> Measurement: """ Allocate a Measurement object or an hdf5 file. Parameters ---------- waves : Waves or SMatrix object The wave function that will define the shape of the diffraction patterns. scan: Scan object The scan object that will define the scan dimensions the measurement. Returns ------- Measurement object or str The allocated measurement or path to hdf5 file with the measurement data. """ waves.grid.check_is_defined() calibrations = calibrations_from_grid(waves.gpts, waves.sampling, names=['x', 'y'], units='Å') array = np.zeros(scan.shape + waves.gpts, dtype=np.complex64) measurement = Measurement(array, calibrations=scan.calibrations + calibrations) if isinstance(self.save_file, str): measurement = measurement.write(self.save_file) return measurement
[docs] def detect(self, waves) -> np.ndarray: """ Detect the complex wave function. Parameters ---------- waves: Waves object The batch of wave functions to detect. Returns ------- 3d complex array The arrays of the Waves object. """ return waves.array