Source code for abtem.inelastic.phonons

"""Module to describe the effect of temperature on the atomic positions."""
from __future__ import annotations

from abc import abstractmethod, ABCMeta
from functools import partial
from numbers import Number
from typing import Sequence, Iterable

import dask
import dask.array as da
import numpy as np
from ase import Atoms
from ase import data
from ase.cell import Cell
from ase.data import chemical_symbols
from ase.io import read
from ase.io.trajectory import read_atoms
from dask.delayed import Delayed

from abtem.core.axes import FrozenPhononsAxis, AxisMetadata, UnknownAxis
from abtem.core.chunks import chunk_ranges, validate_chunks
from abtem.core.ensemble import Ensemble, _wrap_with_array, unpack_blockwise_args
from abtem.core.utils import CopyMixin, EqualityMixin

try:
    from gpaw.io import Reader  # noqa
except ImportError:
    Reader = None


def _safe_read_atoms(calculator, clean: bool = True) -> Atoms:
    if isinstance(calculator, str):
        with Reader(calculator) as reader:
            atoms = read_atoms(reader.atoms)
    else:
        atoms = calculator.atoms

    if clean:
        atoms.constraints = None
        atoms.calc = True

    return atoms


[docs] class BaseFrozenPhonons(Ensemble, EqualityMixin, CopyMixin, metaclass=ABCMeta): """Base class for frozen phonons."""
[docs] def __init__( self, atomic_numbers: np.ndarray, cell: Cell, ensemble_mean: bool = True ): self._cell = cell self._atomic_numbers = atomic_numbers self._ensemble_mean = ensemble_mean
@property def ensemble_mean(self): """The mean of the ensemble of results from a multislice simulation is calculated.""" return self._ensemble_mean @property def atomic_numbers(self) -> np.ndarray: """The unique atomic number of the atoms.""" return self._atomic_numbers @property def cell(self) -> Cell: """The cell of the atoms.""" return self._cell @staticmethod def _validate_atomic_numbers_and_cell(atoms: Atoms, atomic_numbers, cell): if isinstance(atoms, da.core.Array) and ( atomic_numbers is None or cell is None ): atoms = atoms.compute() if cell is None: cell = atoms.cell.copy() else: if not isinstance(cell, Cell): cell = Cell(cell) if not np.allclose(atoms.cell.array, cell.array): raise RuntimeError("cell of provided Atoms did not match provided cell") if atomic_numbers is None: atomic_numbers = np.unique(atoms.numbers) else: atomic_numbers = np.array(atomic_numbers, dtype=int) return atomic_numbers, cell @property @abstractmethod def atoms(self) -> Atoms: """Base atomic configuration used for displacements.""" pass
[docs] @abstractmethod def randomize(self, atoms: Atoms) -> Atoms: """ Randomize the atoms. Parameters ---------- atoms : Atoms """ pass
@abstractmethod def __len__(self) -> int: pass @property @abstractmethod def num_configs(self): """Number of atomic configurations.""" pass def __iter__(self): for _, _, fp in self.generate_blocks(1): fp = fp.item() yield fp.randomize(fp.atoms)
[docs] class DummyFrozenPhonons(BaseFrozenPhonons): """Class to allow all potentials to be treated in the same way."""
[docs] def __init__( self, atoms: Atoms, num_configs: int = None, ): self._atoms = atoms self._num_configs = num_configs atomic_numbers, cell = self._validate_atomic_numbers_and_cell(atoms, None, None) super().__init__(atomic_numbers=atomic_numbers, cell=cell, ensemble_mean=True)
@property def num_configs(self): return self._num_configs @property def ensemble_shape(self): if self._num_configs is None: return () else: return (self._num_configs,) @property def _default_ensemble_chunks(self): if self._num_configs is None: return () else: return (1,) @property def ensemble_axes_metadata(self) -> list[AxisMetadata]: if self._num_configs is None: return [] else: return [FrozenPhononsAxis(_ensemble_mean=self.ensemble_mean)]
[docs] def randomize(self, atoms: Atoms) -> Atoms: return atoms
@property def atoms(self): return self._atoms @classmethod def _from_partitioned_args_func(cls, args, **kwargs): if hasattr(args, "item"): args = args.item() atoms = args new_dummy_frozen_phonons = cls(atoms=atoms, **kwargs) return _wrap_with_array(new_dummy_frozen_phonons, 1) def _from_partitioned_args(self): kwargs = self._copy_kwargs(exclude=("atoms",)) return partial(self._from_partitioned_args_func, **kwargs) def _partition_args(self, chunks: int = 1, lazy: bool = True): if lazy: lazy_args = dask.delayed(_wrap_with_array)(self.atoms, ndims=1) array = da.from_delayed(lazy_args, shape=(1,), dtype=object) else: atoms = self.atoms array = _wrap_with_array(atoms, 1) return (array,) def __len__(self): if self._num_configs is None: return 1 else: return self._num_configs
def _validate_seeds( seeds: int | tuple[int, ...] | None, num_seeds: int = None ) -> tuple[int, ...]: if seeds is None or np.isscalar(seeds): if num_seeds is None: raise ValueError("Provide `num_configs` or a seed for each configuration.") rng = np.random.default_rng(seed=seeds) seeds = () while len(seeds) < num_seeds: seed = rng.integers(np.iinfo(np.int32).max) if seed not in seeds: seeds += (seed,) else: if not hasattr(seeds, "__len__"): raise ValueError if num_seeds is not None: assert num_seeds == len(seeds) return seeds
[docs] class FrozenPhonons(BaseFrozenPhonons): """ The frozen phonons randomly displace the atomic positions to emulate thermal vibrations. Parameters ---------- atoms : ASE.Atoms Atomic configuration used for displacements. num_configs : int Number of frozen phonon configurations. sigmas : float or dict or list If float, the standard deviation of the displacements is assumed to be identical for all atoms. If dict, a displacement standard deviation should be provided for each species. The atomic species can be specified as atomic number or a symbol, using the ASE standard. If list or array, a displacement standard deviation should be provided for each atom. Anistropic displacements may be given by providing a standard deviation for each principal direction. This may be a tuple of three numbers for identical displacements for all atoms. A dict of tuples of three numbers to specify displacements for each species. A list or array with three numbers for each atom. directions : str, optional The displacement directions of the atoms as a string; for example 'xy' (default) for displacement in the `x`- and `y`-direction (i.e. perpendicular to the propagation direction). ensemble_mean : bool, optional If True (default), the mean of the ensemble of results from a multislice simulation is calculated, otherwise, the result of every frozen phonon configuration is returned. seed : int or sequence of int Seed(s) for the random number generator used to generate the displacements, or one seed for each configuration in the frozen phonon ensemble. """
[docs] def __init__( self, atoms: Atoms, num_configs: int, sigmas: float | dict[str | int, float] | Sequence[float], directions: str = "xyz", ensemble_mean: bool = True, seed: int | tuple[int, ...] = None, ): if isinstance(sigmas, dict): atomic_numbers = [data.atomic_numbers[symbol] for symbol in sigmas.keys()] else: atomic_numbers = None atomic_numbers, cell = self._validate_atomic_numbers_and_cell( atoms, atomic_numbers, cell=None ) self._sigmas = sigmas self._directions = directions self._atoms = atoms self._seed = _validate_seeds(seed, num_seeds=num_configs) super().__init__( atomic_numbers=atomic_numbers, cell=cell, ensemble_mean=ensemble_mean )
def _validate_sigmas(self, atoms): unique_symbols = [chemical_symbols[number] for number in self.atomic_numbers] sigmas = self._sigmas if isinstance(sigmas, Number): new_sigmas = {} for symbol in unique_symbols: new_sigmas[symbol] = sigmas anisotropic = False sigmas = new_sigmas elif isinstance(sigmas, dict): anisotropic = any(hasattr(value, "__len__") for value in sigmas.values()) if anisotropic and not all(len(value) == 3 for value in sigmas.values()): raise RuntimeError("Three values for each element must be given for anisotropic displacements.") if not all([symbol in unique_symbols for symbol in sigmas.keys()]): raise RuntimeError( "Displacement standard deviation must be provided for all atomic species." ) elif isinstance(sigmas, Iterable): sigmas = np.array(sigmas, dtype=np.float32) if len(sigmas) != len(atoms): raise RuntimeError( "Displacement standard deviation must be provided for all atoms." ) if len(sigmas.shape) == 2: if sigmas.shape[1] == 3: anisotropic = True else: raise RuntimeError( "Three values for each atom must be given for anisotropic displacements." ) elif len(sigmas.shape) == 1: anisotropic = False else: raise RuntimeError() else: raise ValueError() return sigmas, anisotropic @property def ensemble_shape(self): return (len(self),) @property def _default_ensemble_chunks(self): return (1,) @property def ensemble_axes_metadata(self) -> list[AxisMetadata]: return [FrozenPhononsAxis(_ensemble_mean=self.ensemble_mean)] @property def num_configs(self) -> int: return len(self._seed) @property def seed(self) -> tuple[int, ...]: """Random seed for each displacement configuration.""" return self._seed @property def sigmas(self) -> float | dict[str | int, float] | Sequence[float]: """Displacement standard deviation for each atom.""" return self._sigmas @property def atoms(self) -> Atoms: return self._atoms @property def directions(self) -> str: """The directions of the random displacements.""" return self._directions def __len__(self) -> int: return self.num_configs @property def _axes(self) -> list[int]: axes = [] for direction in list(set(self._directions.lower())): if direction == "x": axes += [0] elif direction == "y": axes += [1] elif direction == "z": axes += [2] else: raise RuntimeError(f"Directions must be 'x', 'y' or 'z', not {axes}.") return axes
[docs] def randomize(self, atoms: Atoms) -> Atoms: sigmas, anisotropic = self._validate_sigmas(atoms) if isinstance(sigmas, dict): if anisotropic: temp = np.zeros((len(atoms.numbers), 3), dtype=np.float32) else: temp = np.zeros(len(atoms.numbers), dtype=np.float32) for unique in np.unique(atoms.numbers): temp[atoms.numbers == unique] = np.float32( sigmas[chemical_symbols[unique]] ) sigmas = temp elif not isinstance(sigmas, np.ndarray): raise RuntimeError() atoms = atoms.copy() rng = np.random.default_rng(self.seed[0]) if anisotropic: r = rng.normal(size=(len(atoms), 3)) for axis in self._axes: atoms.positions[:, axis] += sigmas[:, axis] * r[:, axis] else: r = rng.normal(size=(len(atoms), 3)) / np.sqrt(3) for axis in self._axes: atoms.positions[:, axis] += sigmas * r[:, axis] return atoms
@classmethod def _from_partitioned_args_func(cls, *args, **kwargs): args = unpack_blockwise_args(args) atoms, seed = args[0] new = cls(atoms=atoms, seed=seed, num_configs=len(seed), **kwargs) new = _wrap_with_array(new, len(new.ensemble_shape)) return new def _from_partitioned_args(self): kwargs = self._copy_kwargs(exclude=("atoms", "seed", "num_configs")) output = partial(self._from_partitioned_args_func, **kwargs) return output def _partition_args(self, chunks: int = 1, lazy: bool = True): chunks = validate_chunks(self.ensemble_shape, chunks) if lazy: arrays = [] for i, (start, stop) in enumerate(chunk_ranges(chunks)[0]): seeds = self.seed[start:stop] lazy_atoms = dask.delayed(self.atoms) lazy_args = dask.delayed(_wrap_with_array)((lazy_atoms, seeds), ndims=1) lazy_array = da.from_delayed(lazy_args, shape=(1,), dtype=object) arrays.append(lazy_array) array = da.concatenate(arrays) else: atoms = self.atoms array = np.zeros((len(chunks[0]),), dtype=object) for i, (start, stop) in enumerate(chunk_ranges(chunks)[0]): array.itemset(i, (atoms, self.seed[start:stop])) return (array,)
[docs] def to_atoms_ensemble(self): """ Convert the frozen phonons to an ensemble of atoms. Returns ------- atoms_ensemble : AtomsEnsemble """ trajectory = [] for _, _, block in self.generate_blocks(1): block = block.item() trajectory.append(block.randomize(block.atoms)) return AtomsEnsemble(trajectory)
[docs] class AtomsEnsemble(BaseFrozenPhonons): """ Frozen phonons based on a molecular dynamics simulation. Parameters ---------- trajectory : list of ASE.Atoms, dask.core.Array, list of dask.Delayed Sequence of atoms representing a distribution of atomic configurations. ensemble_mean : True, optional If True, the mean of the ensemble of results from a multislice simulation is calculated, otherwise, the result of every frozen phonon is returned. ensemble_axes_metadata : list of AxesMetadata, optional Axis metadata for each ensemble axis. The axis metadata must be compatible with the shape of the array. cell : Cell, optional """
[docs] def __init__( self, trajectory: Sequence[Atoms], ensemble_mean: bool = True, ensemble_axes_metadata: list[AxisMetadata] = None, cell: Cell = None, ): if isinstance(trajectory, str): trajectory = read(trajectory, index=":") elif isinstance(trajectory, Atoms): trajectory = [trajectory] if isinstance(trajectory, (list, tuple)): if isinstance(trajectory[0], str): trajectory = [_safe_read_atoms(path) for path in trajectory] if isinstance(trajectory[0], Delayed): stack = [] for atoms in trajectory: atoms = dask.delayed(_wrap_with_array)(atoms, 1) atoms = da.from_delayed(atoms, shape=(1,), dtype=object) stack.append(atoms) trajectory = da.concatenate(stack) else: stack = np.empty(len(trajectory), dtype=object) for i, atoms in enumerate(trajectory): stack.itemset(i, atoms) trajectory = stack assert isinstance(trajectory, (np.ndarray, da.core.Array)) if ensemble_axes_metadata is None: ensemble_axes_metadata = [FrozenPhononsAxis(_ensemble_mean=ensemble_mean)] elif isinstance(ensemble_axes_metadata, AxisMetadata): ensemble_axes_metadata = [ensemble_axes_metadata] elif not isinstance(ensemble_axes_metadata, list): raise ValueError() assert len(ensemble_axes_metadata) == len(trajectory.shape) atoms = trajectory.ravel()[0] atomic_numbers, cell = self._validate_atomic_numbers_and_cell(atoms, None, cell) self._trajectory = trajectory super().__init__( atomic_numbers=atomic_numbers, cell=cell, ensemble_mean=ensemble_mean ) self._ensemble_axes_metadata = ensemble_axes_metadata
@property def trajectory(self) -> np.ndarray | da.core.Array: """Array of atoms representing an ensemble of atomic configurations.""" return self._trajectory def __getitem__(self, item): new_trajectory = self._trajectory[item] kwargs = self._copy_kwargs(exclude=("trajectory",)) return AtomsEnsemble(new_trajectory, **kwargs) @property def ensemble_axes_metadata(self) -> list[AxisMetadata]: return self._ensemble_axes_metadata def __len__(self) -> int: return len(self._trajectory) @property def num_configs(self) -> int: return len(self._trajectory) @property def atoms(self) -> Atoms: return self._trajectory.ravel()[0] @property def ensemble_shape(self) -> tuple[int, ...]: if isinstance(self._trajectory, (da.core.Array, np.ndarray)): return self._trajectory.shape return (len(self),) @property def _default_ensemble_chunks(self) -> tuple[int, ...]: if isinstance(self._trajectory, (da.core.Array, np.ndarray)): return (1,) * len(self.ensemble_shape) return (1,) def _partition_args(self, chunks: int = 1, lazy: bool = True): chunks = validate_chunks(self.ensemble_shape, chunks) if lazy: arrays = [] for i, (start, stop) in enumerate(chunk_ranges(chunks)[0]): trajectory = self.trajectory[start:stop] lazy_args = dask.delayed(_wrap_with_array)(trajectory, ndims=1) lazy_array = da.from_delayed(lazy_args, shape=(1,), dtype=object) arrays.append(lazy_array) array = da.concatenate(arrays) else: trajectory = self.trajectory if isinstance(trajectory, da.core.Array): trajectory = trajectory.compute() array = np.zeros((len(chunks[0]),), dtype=object) for i, (start, stop) in enumerate(chunk_ranges(chunks)[0]): array.itemset(i, _wrap_with_array(trajectory[start:stop], 1)) return (array,) @staticmethod def _from_partition_args_func(*args, **kwargs): args = unpack_blockwise_args(args) trajectory = args[0] atoms_ensemble = AtomsEnsemble(trajectory, **kwargs) return _wrap_with_array(atoms_ensemble, 1) def _from_partitioned_args(self): kwargs = self._copy_kwargs(exclude=("trajectory", "ensemble_shape")) kwargs["cell"] = self.cell.array kwargs["ensemble_axes_metadata"] = [UnknownAxis()] * len(self.ensemble_shape) return partial(self._from_partition_args_func, **kwargs)
[docs] def randomize(self, atoms: Atoms) -> Atoms: return atoms
[docs] def standard_deviations(self) -> np.ndarray: """ Standard deviation of the positions of each atom in each direction. """ ensemble_positions = [atoms.positions for atoms in self._trajectory] num_atoms = len(ensemble_positions[0]) if not all(len(positions) == num_atoms for positions in ensemble_positions): raise RuntimeError() mean_positions = np.mean(ensemble_positions, axis=0) squared_deviations = [ (atoms.positions - mean_positions) ** 2 for atoms in self._trajectory ] return np.sqrt(np.sum(squared_deviations, axis=0) / (len(self) - 1))