"""Module for plotting atoms, images, line scans, and diffraction patterns."""
from __future__ import annotations
import string
from abc import abstractmethod, ABCMeta
from collections import defaultdict
from typing import TYPE_CHECKING, Sequence, Iterable
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
from ase import Atoms
from ase.data import covalent_radii, chemical_symbols
from ase.data.colors import jmol_colors
from matplotlib import colors
from matplotlib.animation import FuncAnimation
from matplotlib.axes import Axes
from matplotlib.collections import PatchCollection, EllipseCollection
from matplotlib.colors import ListedColormap
from matplotlib.lines import Line2D
from matplotlib.offsetbox import AnchoredText
from matplotlib.patches import Circle
from mpl_toolkits.axes_grid1 import Size, Divider
from mpl_toolkits.axes_grid1.anchored_artists import AnchoredSizeBar
from mpl_toolkits.axes_grid1.axes_grid import _cbaraxes_class_factory
from traitlets.traitlets import link
from abtem.atoms import pad_atoms, plane_to_axes
from abtem.core import config
from abtem.core.axes import ReciprocalSpaceAxis, format_label, LinearAxis
from abtem.core.colors import hsluv_cmap
from abtem.core.units import _get_conversion_factor
from abtem.core.utils import label_to_index
try:
import ipywidgets as widgets
except ImportError:
widgets = None
ipywidgets_not_installed = RuntimeError(
"This functionality of abTEM requires ipywidgets, see "
"https://ipywidgets.readthedocs.io/en/stable/user_install.html."
)
if TYPE_CHECKING:
from abtem.measurements import (
BaseMeasurements,
_BaseMeasurement1D,
_BaseMeasurement2D,
IndexedDiffractionPatterns,
)
def _make_default_sizes():
    """
    Build the default spacing entries used when laying out a grid of axes.

    Returns
    -------
    sizes : dict
        Newly created ``Size.Fixed`` objects keyed by ``"cbar_padding_left"``,
        ``"cbar_spacing"``, ``"cbar_padding_right"`` and ``"padding"``. A new
        dictionary is created on every call, so callers may mutate it freely.
    """
    defaults = (
        ("cbar_padding_left", 0.15),
        ("cbar_spacing", 0.4),
        ("cbar_padding_right", 0.9),
        ("padding", 0.1),
    )
    return {name: Size.Fixed(value) for name, value in defaults}
def _cbar_layout(n, sizes):
if n == 0:
return []
layout = [sizes["cbar_padding_left"]]
for i in range(n):
layout.extend([sizes["cbar"]])
if i < n - 1:
layout.extend([sizes["cbar_spacing"]])
layout.extend([sizes["cbar_padding_right"]])
return layout
def _make_grid_layout(
    axes, ncbars: int, sizes: dict, cbar_mode: str = "each", direction: str = "col"
):
    """
    Build the list of size entries for one direction of an axes grid.

    Parameters
    ----------
    axes : sequence of Axes
        The axes along the direction being laid out; the first is used as the
        reference axes for all size entries.
    ncbars : int
        Number of colorbars per colorbar slot.
    sizes : dict
        Spacing entries. If it has no ``"cbar"`` entry, one is added in place
        as 5 % of the first axes size entry.
    cbar_mode : {"each", "single"}
        Whether colorbars follow every axes or only the last one.
    direction : {"col", "row"}
        Selects ``Size.AxesX`` ("col") or ``Size.AxesY`` ("row") entries.

    Returns
    -------
    layout : list
        The size entries in layout order.

    Raises
    ------
    ValueError
        If ``cbar_mode`` or ``direction`` is not one of the accepted values.
    """
    if cbar_mode not in ("single", "each"):
        raise ValueError()

    layout = []
    for position, ax in enumerate(axes):
        if direction == "col":
            axes_size = Size.AxesX
        elif direction == "row":
            axes_size = Size.AxesY
        else:
            raise ValueError()
        layout.append(axes_size(ax, aspect="axes", ref_ax=axes[0]))

        # The colorbar width is defined relative to the first axes entry.
        if "cbar" not in sizes:
            sizes["cbar"] = Size.from_any("5%", layout[0])

        if cbar_mode == "each":
            layout += _cbar_layout(ncbars, sizes)

        # Padding only goes between axes, not after the last one.
        if position < len(axes) - 1:
            layout.append(sizes["padding"])

    if cbar_mode == "single":
        layout += _cbar_layout(ncbars, sizes)

    return layout
class AxesGrid:
    """
    A grid of matplotlib axes, optionally with colorbar axes, positioned by an
    ``axes_grid1`` ``Divider``.

    Parameters
    ----------
    fig : matplotlib.figure.Figure
        The figure that the axes are added to.
    ncols : int
        Number of columns in the grid.
    nrows : int
        Number of rows in the grid.
    ncbars : int
        Number of colorbar axes per colorbar slot.
    cbar_mode : {"single", "each"}
        ``"each"`` creates colorbar axes for every panel; ``"single"`` creates
        one set of colorbar axes, stored under the first panel.
    aspect : bool
        Passed on to the ``Divider``.
    sharex, sharey : bool
        If true, all panels share the x (y) axis with the first panel, and the
        corresponding tick labels are hidden on the inner panels.
    rect : tuple
        Position rectangle passed to the axes and to the ``Divider``.
    col_sizes, row_sizes : dict, optional
        Spacing entries for the column and row layouts. Defaults are created by
        ``_make_default_sizes``. The dictionaries are stored and later modified
        by the ``set_*`` methods.
    """

    def __init__(
        self,
        fig,
        ncols: int,
        nrows: int,
        ncbars: int = 0,
        cbar_mode: str = "single",
        aspect: bool = True,
        sharex: bool = True,
        sharey: bool = True,
        rect: tuple = (0.1, 0.1, 0.85, 0.85),
        col_sizes: dict = None,
        row_sizes: dict = None,
    ):
        # Deliberately shadows the module-level matplotlib ``Axes`` inside this method.
        from mpl_toolkits.axes_grid1.mpl_axes import Axes

        self._ncols = ncols
        self._nrows = nrows
        self._ncbars = ncbars
        self._aspect = aspect
        self._sharex = sharex
        self._sharey = sharey
        self._col_sizes = _make_default_sizes() if col_sizes is None else col_sizes
        self._row_sizes = _make_default_sizes() if row_sizes is None else row_sizes

        # Create the panels in column-major order; every panel after the first
        # optionally shares its axes with the first one.
        panels = []
        for _ in range(ncols):
            for _ in range(nrows):
                if panels:
                    leader = panels[0]
                    panel = Axes(
                        fig,
                        rect,
                        sharex=leader if sharex else None,
                        sharey=leader if sharey else None,
                    )
                else:
                    panel = Axes(fig, rect, sharex=None, sharey=None)
                panels.append(panel)

        for panel in panels:
            fig.add_axes(panel)

        panel_grid = np.array(panels, dtype=object).reshape((ncols, nrows))

        col_layout = _make_grid_layout(
            panel_grid[:, 0],
            ncbars=ncbars,
            sizes=self._col_sizes,
            cbar_mode=cbar_mode,
            direction="col",
        )
        row_layout = _make_grid_layout(
            panel_grid[0], ncbars=0, sizes=self._row_sizes, direction="row"
        )

        self._divider = Divider(
            fig, rect, horizontal=col_layout, vertical=row_layout, aspect=aspect
        )

        def make_cbar_axes(**locator_kwargs):
            # Create a vertical colorbar axes and pin it to a divider cell.
            cbar_axes = _cbaraxes_class_factory(Axes)(
                fig, self._divider.get_position(), orientation="vertical"
            )
            fig.add_axes(cbar_axes)
            cbar_axes.set_axes_locator(self._divider.new_locator(**locator_kwargs))
            return cbar_axes

        if cbar_mode == "single":
            caxes = {panels[0]: []}
        else:
            caxes = {panel: [] for panel in panels}

        panel_counter = 0
        cbar_counter = 0
        for nx, col_size in enumerate(col_layout):
            for ny, row_size in enumerate(row_layout):
                on_panel_row = isinstance(row_size, Size.AxesY)

                if isinstance(col_size, Size.AxesX) and on_panel_row:
                    panels[panel_counter].set_axes_locator(
                        self._divider.new_locator(nx=nx, ny=ny)
                    )
                    panel_counter += 1

                if (
                    cbar_mode == "each"
                    and col_size is self._col_sizes["cbar"]
                    and on_panel_row
                ):
                    # Map the running colorbar count back to the panel it belongs to.
                    owner = panels[
                        np.ravel_multi_index(
                            (cbar_counter // (ncbars * nrows), cbar_counter % nrows),
                            (ncols, nrows),
                        )
                    ]
                    cbar_counter += 1
                    caxes[owner].append(make_cbar_axes(nx=nx, ny=ny))

                if (
                    cbar_mode == "single"
                    and len(caxes[panels[0]]) < ncbars
                    and col_size is self._col_sizes["cbar"]
                    and on_panel_row
                ):
                    # Shared colorbars span all rows; consecutive colorbars sit
                    # two layout entries apart (colorbar, spacing, colorbar, ...).
                    for i in range(ncbars):
                        caxes[panels[0]].append(
                            make_cbar_axes(nx=nx + i * 2, ny=0, ny1=-1)
                        )

        if sharex:
            for panel in panel_grid[:, 1:].ravel():
                panel._axislines["bottom"].toggle(ticklabels=False, label=False)

        if sharey:
            for panel in panel_grid[1:].ravel():
                panel._axislines["left"].toggle(ticklabels=False, label=False)

        self._axes = panel_grid
        self._caxes = caxes

    @property
    def divider(self):
        """The ``Divider`` positioning the axes of the grid."""
        return self._divider

    @property
    def ncols(self) -> int:
        """Number of columns, i.e. the size of the first dimension of the axes array."""
        return self._axes.shape[0]

    @property
    def nrows(self) -> int:
        """Number of rows, i.e. the size of the second dimension of the axes array."""
        return self._axes.shape[1]

    def __getitem__(self, item):
        return self._axes[item]

    def __len__(self):
        return len(self._axes)

    @property
    def shape(self) -> tuple[int, int]:
        """Shape of the axes array as ``(ncols, nrows)``."""
        return self._axes.shape

    def set_cbar_padding(self, padding: tuple[float, float] = (0.1, 0.1)):
        """
        Set the padding to the left and right of the colorbars.

        Parameters
        ----------
        padding : float or tuple of float
            The left and right padding; a scalar is used for both.
        """
        left, right = (padding,) * 2 if np.isscalar(padding) else (padding[0], padding[1])
        for sizes in (self._col_sizes, self._row_sizes):
            sizes["cbar_padding_left"].fixed_size = left
        for sizes in (self._col_sizes, self._row_sizes):
            sizes["cbar_padding_right"].fixed_size = right

    def set_cbar_size(self, fraction: float):
        """
        Set the colorbar size as a fraction of its reference size entry.

        Parameters
        ----------
        fraction : float
            The new fraction, written to the ``"cbar"`` size entries.
        """
        for sizes in (self._col_sizes, self._row_sizes):
            sizes["cbar"]._fraction = fraction

    def set_cbar_spacing(self, spacing: float):
        """
        Set the spacing between adjacent colorbars.

        Parameters
        ----------
        spacing : float
            The new fixed size of the ``"cbar_spacing"`` entries.
        """
        for sizes in (self._col_sizes, self._row_sizes):
            sizes["cbar_spacing"].fixed_size = spacing

    def set_axes_padding(self, padding: float | tuple[float, float] = (0.0, 0.0)):
        """
        Set the padding between the axes of the grid.

        Parameters
        ----------
        padding : float or tuple of float
            The padding along columns and rows; a scalar is used for both.
        """
        if np.isscalar(padding):
            padding = (padding,) * 2
        col_padding, row_padding = padding[0], padding[1]
        self._col_sizes["padding"].fixed_size = col_padding
        self._row_sizes["padding"].fixed_size = row_padding
def _axes_grid_cols_and_rows(measurements, axes_types):
shape = measurements.ensemble_shape
shape = tuple(
n
for n, axes_type in zip(shape, axes_types)
if not axes_type in ("index", "range", "overlay")
)
if len(shape) > 0:
ncols = shape[0]
else:
ncols = 1
if len(shape) > 1:
nrows = shape[1]
else:
nrows = 1
return ncols, nrows
def _determine_axes_types(
measurements: BaseMeasurements,
explode: bool | tuple[bool, ...] | None,
overlay: bool | tuple[bool, ...] | None,
):
num_ensemble_axes = len(measurements.ensemble_shape)
axes_types = []
for axis_metadata in measurements.ensemble_axes_metadata:
if axis_metadata._default_type is not None:
axes_types.append(axis_metadata._default_type)
else:
axes_types.append("index")
if explode is True:
explode = tuple(range(max(num_ensemble_axes - 2, 0), num_ensemble_axes))
elif explode is False:
explode = ()
if overlay is True:
overlay = tuple(range(max(num_ensemble_axes - 2, 0), num_ensemble_axes))
elif overlay is False:
overlay = ()
axes_types = list(axes_types)
for i, axis_type in enumerate(axes_types):
if explode is not None:
if i in explode:
axes_types[i] = "explode"
else:
axes_types[i] = "index"
if overlay is not None:
if i in overlay:
axes_types[i] = "overlay"
elif i not in explode:
axes_types[i] = "index"
return axes_types
def _validate_axes(
    measurements: BaseMeasurements,
    ax: Axes,
    explode: bool = False,
    overlay: bool = False,
    cbar: bool = False,
    common_color_scale: bool = False,
    figsize: tuple[float, float] = None,
    ioff: bool = False,
    aspect: bool = True,
    sharex: bool = True,
    sharey: bool = True,
):
    """
    Return the grid of axes that a visualization should draw on.

    Parameters
    ----------
    measurements : BaseMeasurements
        The measurements to be shown; used to derive the grid shape and the
        number of colorbars (two for complex measurements, otherwise one).
    ax : Axes or None
        An existing axes to draw on. If None, a new figure with an
        :class:`AxesGrid` is created.
    explode, overlay : bool
        Forwarded to ``_determine_axes_types``.
    cbar : bool
        Whether colorbar axes are reserved.
    common_color_scale : bool
        If true a single set of colorbars is created, otherwise one per panel.
    figsize : tuple of float, optional
        Size of the newly created figure.
    ioff : bool
        If true the figure is created inside ``plt.ioff()``.
    aspect, sharex, sharey : bool
        Forwarded to :class:`AxesGrid`.

    Returns
    -------
    axes : AxesGrid or np.ndarray
        The new grid, or the given axes wrapped in a 1 x 1 array.

    Raises
    ------
    NotImplementedError
        If ``ax`` is given together with a truthy ``explode``.
    """
    axes_types = _determine_axes_types(measurements, explode, overlay)

    if not cbar:
        ncbars = 0
    elif measurements.is_complex:
        ncbars = 2
    else:
        ncbars = 1

    cbar_mode = "single" if common_color_scale else "each"

    if ax is not None:
        if explode:
            raise NotImplementedError("`ax` not implemented with `explode = True`.")
        return np.array([[ax]])

    if ioff:
        with plt.ioff():
            fig = plt.figure(figsize=figsize)
    else:
        fig = plt.figure(figsize=figsize)

    ncols, nrows = _axes_grid_cols_and_rows(measurements, axes_types)
    return AxesGrid(
        fig=fig,
        ncols=ncols,
        nrows=nrows,
        ncbars=ncbars,
        cbar_mode=cbar_mode,
        aspect=aspect,
        sharex=sharex,
        sharey=sharey,
    )
def _format_options(options):
formatted_options = []
for option in options:
if isinstance(option, float):
formatted_options.append(f"{option:.3f}")
elif isinstance(option, tuple):
formatted_options.append(
", ".join(tuple(f"{value:.3f}" for value in option))
)
else:
formatted_options.append(option)
return formatted_options
def discrete_cmap(num_colors, base_cmap):
    """
    Create a colormap with a fixed number of colors sampled from a base colormap.

    Parameters
    ----------
    num_colors : int
        Number of colors in the resulting colormap.
    base_cmap : str or colormap
        The colormap to sample; a string is resolved with ``plt.get_cmap``.
        It is called with the integers ``0 .. num_colors - 1``.

    Returns
    -------
    cmap : matplotlib.colors.LinearSegmentedColormap
        Colormap built from the sampled colors with ``num_colors`` levels.
    """
    cmap = plt.get_cmap(base_cmap) if isinstance(base_cmap, str) else base_cmap
    sampled = cmap(range(num_colors))
    return matplotlib.colors.LinearSegmentedColormap.from_list("", sampled, num_colors)
# def _check_ensemble_axes_equal(masurements):
def make_sliders_from_ensemble_axes(
    visualizations: MeasurementVisualization | Sequence[MeasurementVisualization],
    axes_types: tuple[str, ...],
    continuous_update: bool = False,
    callbacks: tuple[callable, ...] = (),
):
    """
    Create ipywidgets sliders for the "index" and "range" ensemble axes of one
    or more visualizations and connect them to the visualizations.

    Parameters
    ----------
    visualizations : MeasurementVisualization or sequence of MeasurementVisualization
        The visualizations to control. All must have equal ensemble axes
        metadata and ensemble shape.
    axes_types : tuple of str
        Type of each ensemble axis. A "range" axis gets a
        ``SelectionRangeSlider``, an "index" axis a ``SelectionSlider``; other
        types get no slider.
    continuous_update : bool
        Passed on to the sliders.
    callbacks : tuple of callable
        Additional observers registered on the "value" trait of every slider.

    Returns
    -------
    sliders : list
        The created slider widgets, in ensemble axis order.

    Raises
    ------
    ValueError
        If the visualizations are not mutually compatible.
    RuntimeError
        If a slider is required but ipywidgets is not installed.
    """
    if not isinstance(visualizations, Sequence):
        visualizations = [visualizations]

    ensemble_axes_metadata = visualizations[0].measurements.ensemble_axes_metadata
    ensemble_shape = visualizations[0].measurements.ensemble_shape

    for visualization in visualizations[1:]:
        if not isinstance(visualization, MeasurementVisualization):
            raise ValueError(
                "all visualizations must be instances of MeasurementVisualization"
            )

        if not (
            (
                visualization.measurements.ensemble_axes_metadata
                == ensemble_axes_metadata
            )
            and (visualization.measurements.ensemble_shape == ensemble_shape)
        ):
            raise ValueError(
                "all visualizations must have identical ensemble axes metadata and "
                "ensemble shape"
            )

    sliders = []
    for axes_metadata, n, axes_type in zip(
        ensemble_axes_metadata,
        ensemble_shape,
        axes_types,
    ):
        options = _format_options(axes_metadata.coordinates(n))

        # Widget descriptions are plain text, so TeX formatting is disabled.
        with config.set({"visualize.use_tex": False}):
            label = axes_metadata.format_label()

        if axes_type not in ("range", "index"):
            continue

        # ``widgets`` is None when the optional ipywidgets import failed; fail
        # with the module's explanatory error instead of an AttributeError.
        if widgets is None:
            raise ipywidgets_not_installed

        if axes_type == "range":
            sliders.append(
                widgets.SelectionRangeSlider(
                    description=label,
                    options=options,
                    continuous_update=continuous_update,
                    index=(0, len(options) - 1),
                )
            )
        else:
            sliders.append(
                widgets.SelectionSlider(
                    description=label,
                    options=options,
                    continuous_update=continuous_update,
                )
            )

    for visualization in visualizations:
        _set_update_indices_callback(sliders, visualization, callbacks)

    return sliders
def _set_update_indices_callback(sliders, visualization, callbacks=()):
def update_indices(change):
indices = ()
for slider in sliders:
idx = slider.index
if isinstance(idx, tuple):
idx = slice(*idx)
indices += (idx,)
with sliders[0].hold_trait_notifications():
visualization.set_ensemble_indices(indices)
if visualization._autoscale:
vmin, vmax = visualization.get_global_vmin_vmax()
visualization._update_vmin_vmax(vmin, vmax)
for slider in sliders:
slider.observe(update_indices, "value")
for callback in callbacks:
slider.observe(callback, "value")
def _make_continuous_button(sliders):
    """
    Create a toggle button controlling the ``continuous_update`` trait of sliders.

    Parameters
    ----------
    sliders : iterable
        The slider widgets to link to the button.

    Returns
    -------
    button : widgets.ToggleButton
        Toggle button whose value is linked to ``continuous_update`` of every
        slider. Its initial value is the ``visualize.continuous_update``
        configuration value (False if unset).
    """
    toggle = widgets.ToggleButton(
        description="Continuous update",
        value=config.get("visualize.continuous_update", False),
    )
    source = (toggle, "value")
    for slider in sliders:
        link(source, (slider, "continuous_update"))
    return toggle
def _get_max_range(array, axes_types):
if np.iscomplexobj(array):
array = np.abs(array)
max_values = array.max(
tuple(
i for i, axes_type in enumerate(axes_types) if axes_type not in ("range",)
)
)
positive_indices = np.where(max_values > 0)[0]
if len(positive_indices) <= 1:
max_value = np.max(max_values)
else:
max_value = np.sum(max_values[positive_indices])
return max_value
def _make_vmin_vmax_slider(visualization):
    """
    Create a range slider controlling the color normalization limits.

    Parameters
    ----------
    visualization : MeasurementVisualization
        The visualization whose ``_update_vmin_vmax`` is called on changes.

    Returns
    -------
    slider : widgets.FloatRangeSlider
        Slider spanning the range estimated with ``_get_max_range``, with a
        step of one millionth of that range. It is disabled while the
        visualization's autoscale flag is set.
    """
    measurements = visualization.measurements
    # Base axes are never summed, so they get a non-"range" placeholder type.
    axes_types = tuple(visualization._axes_types) + ("base",) * measurements.num_base_axes

    upper = _get_max_range(visualization.measurements.array, axes_types)
    lower = -_get_max_range(-visualization.measurements.array, axes_types)
    step = (upper - lower) / 1e6

    slider = widgets.FloatRangeSlider(
        value=visualization._get_vmin_vmax(),
        min=lower,
        max=upper,
        step=step,
        disabled=visualization._autoscale,
        description="Normalization",
        # readout=False,
        continuous_update=True,
    )

    def on_value_change(change):
        low, high = change["new"]
        # Keep the upper limit at least one step above the lower limit.
        high = max(high, low + step)
        with slider.hold_trait_notifications():
            visualization._update_vmin_vmax(low, high)

    slider.observe(on_value_change, "value")
    return slider
def _make_scale_button(visualization):
    """
    Create a button that rescales the color limits to the displayed data.

    Parameters
    ----------
    visualization : MeasurementVisualization
        The visualization to rescale.

    Returns
    -------
    button : widgets.Button
        Button that, when clicked, applies the limits returned by
        ``get_global_vmin_vmax`` through ``_update_vmin_vmax``.
    """

    def rescale(*args):
        limits = visualization.get_global_vmin_vmax()
        low, high = limits
        visualization._update_vmin_vmax(low, high)

    button = widgets.Button(description="Scale")
    button.on_click(rescale)
    return button
def _make_autoscale_button(visualization):
    """
    Create a toggle button controlling the autoscale flag of a visualization.

    Parameters
    ----------
    visualization : MeasurementVisualization
        The visualization whose ``_autoscale`` attribute is set.

    Returns
    -------
    button : widgets.ToggleButton
        Toggle button initialised from the current ``_autoscale`` value.
    """
    toggle = widgets.ToggleButton(
        value=visualization._autoscale,
        description="Autoscale",
        tooltip="Autoscale",
    )

    def on_toggle(change):
        # Store a plain boolean regardless of the type of the new value.
        visualization._autoscale = bool(change["new"])

    toggle.observe(on_toggle, "value")
    return toggle
def _make_power_scale_slider(visualization):
    """
    Create a slider controlling the power used for the color normalization.

    Parameters
    ----------
    visualization : MeasurementVisualization
        The visualization whose ``_update_power`` is called on changes.

    Returns
    -------
    slider : widgets.FloatSlider
        Slider from 0.01 to 2 in steps of 0.01, initialised from
        ``visualization._get_power()``.
    """
    slider = widgets.FloatSlider(
        value=visualization._get_power(),
        min=0.01,
        max=2,
        step=0.01,
        description="Power",
        tooltip="Power",
    )

    def on_power_change(change):
        new_power = change["new"]
        visualization._update_power(new_power)

    slider.observe(on_power_change, "value")
    return slider
def _get_joined_titles(measurement, formatting, **kwargs):
titles = []
for axes_metadata in measurement.ensemble_axes_metadata:
titles.append(axes_metadata.format_title(formatting, **kwargs))
return "\n".join(titles)
class MeasurementVisualization(metaclass=ABCMeta):
    """
    Abstract base class for visualizations of measurement ensembles on a grid
    of axes.

    Parameters
    ----------
    measurements : BaseMeasurements
        The measurements to visualize; ``to_cpu`` is called on them.
    axes : AxesGrid or np.ndarray
        Two-dimensional grid of matplotlib axes.
    axes_types : sequence of str
        Type of each ensemble axis: "index", "range", "explode" or "overlay".
    autoscale : bool
        Initial value of the autoscale flag.
    """

    def __init__(
        self,
        measurements: BaseMeasurements,
        axes: AxesGrid | np.ndarray,
        axes_types: Sequence[str] = (),
        autoscale: bool = False,
    ):
        self._measurements = measurements.to_cpu()
        self._axes = axes
        self._axes_types = axes_types
        self._indices = self._validate_ensemble_indices()
        self._column_titles = []
        self._row_titles = []
        self._panel_labels = []
        self._metadata_labels = np.array([])
        self._xunits = None
        self._yunits = None
        self._autoscale = autoscale

        for ax in np.array(self.axes).ravel():
            ax.ticklabel_format(
                style="sci", scilimits=(-3, 3), axis="both", useMathText=True
            )

        self.fig.canvas.header_visible = False

    @property
    def autoscale(self):
        """Whether the color scale is automatically updated when indices change."""
        return self._autoscale

    @autoscale.setter
    def autoscale(self, value):
        self._autoscale = value

    @property
    def fig(self):
        """The figure that the first axes of the grid belongs to."""
        return self._axes[0, 0].get_figure()

    @property
    @abstractmethod
    def artists(self):
        pass

    def adjust_figure_aspect(self):
        """Set the figure height so that the figure matches the aspect of its tight bounding box."""
        bbox = self.fig.get_tightbbox()
        aspect = (bbox.ymax - bbox.ymin) / (bbox.xmax - bbox.xmin)
        size = self.fig.get_size_inches()
        self.fig.set_size_inches((size[0], size[0] * aspect))

    def adjust_axes_position(self, rect):
        """Set the position rectangle of the divider of the axes grid."""
        self.axes.divider.set_position(rect)

    def _generate_measurements(self, keepdims: bool = True):
        """
        Iterate over the measurements shown in the panels.

        Yields
        ------
        axes_index : tuple of int
            Two-element index of the panel; built from the indices along the
            "explode" axes and padded with zeros.
        measurement
            The corresponding item of the indexed measurements, with
            "overlay" axes kept whole.
        """
        indexed_measurements = self._get_indexed_measurements()

        # Overlay axes are not iterated over; they are passed on as a whole.
        shape = tuple(
            n if axes_type != "overlay" else 1
            for n, axes_type in zip(
                indexed_measurements.ensemble_shape, self._axes_types
            )
        )

        for indices in np.ndindex(*shape):
            axes_index = tuple(
                i
                for i, axes_type in zip(indices, self._axes_types)
                if axes_type == "explode"
            )
            axes_index = (axes_index + (0,) * (2 - len(axes_index)))[:2]

            indices = tuple(
                i if axes_type != "overlay" else slice(None)
                for i, axes_type in zip(indices, self._axes_types)
            )
            yield axes_index, indexed_measurements.get_items(indices, keepdims=keepdims)

    def set_axes_padding(self, padding: float | tuple[float, float] = (0.0, 0.0)):
        """
        Set the padding between the axes in an :class:`.AxesGrid`.

        Parameters
        ----------
        padding : float or tuple of float
            The padding along columns and rows.
        """
        self._axes.set_axes_padding(padding)

    def _get_axes_from_axes_types(self, axes_type):
        """Return the indices of the ensemble axes that have the given type."""
        return tuple(
            i
            for i, checked_axes_type in enumerate(self.axes_types)
            if checked_axes_type == axes_type
        )

    def _get_indexed_measurements(self, keepdims: bool = True):
        """
        Return the measurements selected by the current ensemble indices, with
        the "range" axes summed.

        Parameters
        ----------
        keepdims : bool
            Passed to ``get_items`` and ``sum`` of the measurements.
        """
        indexed = self.measurements.get_items(self._indices, keepdims=keepdims)

        if keepdims:
            summed_axes = tuple(
                i
                for i, axes_type in enumerate(self._axes_types)
                if axes_type == "range"
            )
        else:
            # NOTE(review): the counter only advances on "range" axes, i.e. the
            # range axes are assumed to come first among the axes remaining
            # after indexing - confirm against ``get_items``.
            i = 0
            summed_axes = ()
            for axes_type in self._axes_types:
                if axes_type == "range":
                    summed_axes += (i,)
                    i += 1

        indexed = indexed.sum(axis=summed_axes, keepdims=keepdims)
        return indexed

    def set_column_titles(
        self,
        titles: str | list[str] = None,
        pad: float = 10.0,
        format: str = ".3g",
        units: str = None,
        fontsize=12,
        **kwargs,
    ):
        """
        Set the titles for the columns of the grid of axes.

        Parameters
        ----------
        titles : str or list of str, optional
            If given as list, each item is the title of one column. If given as
            string, the same title is given to all columns. If not given, the
            titles are derived from the axes metadata.
        pad : float, optional
            Vertical offset of the titles in points.
        format : str, optional
            String formatting of titles derived from axes metadata.
        units : str, optional
            The units used for titles derived from axes metadata.
        fontsize : float, optional
            Font size of the titles.
        **kwargs
            Additional keyword arguments passed to ``Axes.annotate``.
        """
        indexed_measurements = self._get_indexed_measurements(keepdims=False)

        if titles is None or titles is True:
            if not len(indexed_measurements.ensemble_shape):
                return

            # TODO: same for row titles
            # Use the first ensemble axis that is not an overlay axis.
            j = 0
            for j, axes_type in enumerate(self.axes_types):
                if not axes_type == "overlay":
                    break

            axes_metadata = indexed_measurements.ensemble_axes_metadata[j]

            if hasattr(axes_metadata, "to_nonlinear_axis"):
                axes_metadata = axes_metadata.to_nonlinear_axis(
                    indexed_measurements.ensemble_shape[j]
                )

            titles = []
            for i, axis_metadata in enumerate(axes_metadata):
                titles.append(
                    axis_metadata.format_title(
                        format, units=units, include_label=i == 0
                    )
                )
                if i == indexed_measurements.ensemble_shape[j]:
                    break

        elif isinstance(titles, str):
            # One copy per column, so that every column can be titled.
            titles = [titles] * self.ncols

        for column_title in self._column_titles:
            column_title.remove()

        column_titles = []
        for i, ax in enumerate(self.axes[:, -1]):
            annotation = ax.annotate(
                titles[i],
                xy=(0.5, 1),
                xytext=(0, pad),
                xycoords="axes fraction",
                textcoords="offset points",
                ha="center",
                va="baseline",
                fontsize=fontsize,
                **kwargs,
            )
            column_titles.append(annotation)

        self._column_titles = column_titles

    @abstractmethod
    def _get_default_xlabel(self, units=None):
        pass

    @abstractmethod
    def _get_default_ylabel(self, units=None):
        pass

    def set_xlabels(self, label: str = None):
        """
        Set the x-axis label on the axes with second grid index zero.

        Parameters
        ----------
        label : str, optional
            The label; if not given the default label for the current x-units
            is used.
        """
        if label is None:
            label = self._get_default_xlabel(units=self._xunits)

        for i, j in np.ndindex(self.axes.shape):  # noqa
            if j == 0:
                self.axes[i, j].set_xlabel(label)

    def set_ylabels(self, label: str = None):
        """
        Set the y-axis label on the axes with first grid index zero.

        Parameters
        ----------
        label : str, optional
            The label; if not given the default label for the current y-units
            is used.
        """
        if label is None:
            label = self._get_default_ylabel(units=self._yunits)

        for i, j in np.ndindex(self.axes.shape):  # noqa
            if i == 0:
                self.axes[i, j].set_ylabel(label)

    # ``set_xlim`` and ``set_ylim`` are abstract. Earlier concrete definitions
    # in the class body were shadowed by these declarations and never took
    # effect, so only the effective declarations are kept.
    @abstractmethod
    def set_xlim(self):
        pass

    @abstractmethod
    def set_ylim(self):
        pass

    @abstractmethod
    def _get_default_xunits(self):
        pass

    @abstractmethod
    def _get_default_yunits(self):
        pass

    def set_xunits(self, units: str = None):
        """
        Set the units for the x-axis.

        Parameters
        ----------
        units : str
            The name of the units. Must be compatible with existing units.
        """
        self._xunits = self._get_default_xunits() if units is None else units
        self.set_xlabels()
        self.set_xlim()

    def set_yunits(self, units: str = None):
        """
        Set the units for the y-axis.

        Parameters
        ----------
        units : str
            The name of the units. Must be compatible with existing units.
        """
        self._yunits = self._get_default_yunits() if units is None else units
        self.set_ylabels()
        self.set_ylim()

    def set_row_titles(
        self,
        titles: str | list[str] = None,
        shift: float = 0.0,
        format: str = ".3g",
        units: str = None,
        **kwargs,
    ):
        """
        Set the titles for the rows of the grid of axes.

        Parameters
        ----------
        titles : str or list of str, optional
            If given as list, each item is given as a title for a row, the list must have the same length as the number
            of rows. If given as string the same title is given to all rows. If not given the titles are derived from
            the axes metadata.
        shift : float, optional
            Horizontal shift of the title positions.
        format : str, optional
            String formatting of titles derived from axes metadata.
        units : str, optional
            The units used for titles derived from axes metadata.
        **kwargs
            Additional keyword arguments passed to ``Axes.annotate``; the font
            size defaults to 12.
        """
        indexed_measurements = self._get_indexed_measurements()

        if "fontsize" not in kwargs:
            kwargs.update({"fontsize": 12})

        if titles is None:
            if not len(indexed_measurements.ensemble_shape) > 1:
                return

            axes_metadata = indexed_measurements.ensemble_axes_metadata[1]

            if hasattr(axes_metadata, "to_nonlinear_axis"):
                axes_metadata = axes_metadata.to_nonlinear_axis(
                    indexed_measurements.ensemble_shape[1]
                )

            titles = []
            for i, axis_metadata in enumerate(axes_metadata):
                titles.append(
                    axis_metadata.format_title(
                        format, units=units, include_label=i == 0
                    )
                )
                if i == indexed_measurements.ensemble_shape[1]:
                    break

        elif isinstance(titles, str):
            # One copy per row. The number of ensemble dimensions was used
            # here before, which left too few titles for grids with more rows
            # than ensemble dimensions.
            titles = [titles] * self.nrows

        for row_title in self._row_titles:
            row_title.remove()

        row_titles = []
        for i, ax in enumerate(self.axes[0, :]):
            annotation = ax.annotate(
                titles[i],
                xy=(0, 0.5),
                xytext=(-ax.yaxis.labelpad - shift, 0),
                xycoords=ax.yaxis.label,
                textcoords="offset points",
                ha="right",
                va="center",
                rotation=90,
                **kwargs,
            )
            row_titles.append(annotation)

        self._row_titles = row_titles

    @property
    def ncols(self):
        """Number of columns of the axes grid."""
        return self._axes.shape[0]

    @property
    def nrows(self):
        """Number of rows of the axes grid."""
        return self._axes.shape[1]

    @property
    def axes_types(self):
        """The type of each ensemble axis."""
        return self._axes_types

    @property
    def indices(self):
        """The validated indices into the ensemble axes."""
        return self._indices

    @property
    def measurements(self):
        """The visualized measurements."""
        return self._measurements

    @property
    def axes(self):
        """The grid of axes."""
        return self._axes

    def _validate_ensemble_indices(self, indices: int | tuple[int, ...] = ()):
        """
        Expand indices for the "index" and "range" axes to one entry per
        ensemble axis.

        Parameters
        ----------
        indices : int or tuple of int
            Indices for the axes that are neither "explode" nor "overlay", in
            order. Missing entries default to 0 for "index" axes and to a full
            slice for "range" axes.

        Returns
        -------
        validated_indices : tuple
            One index or slice per ensemble axis; "explode" and "overlay" axes
            always get a full slice.

        Raises
        ------
        ValueError
            If more indices are given than there are indexable axes.
        RuntimeError
            If an axes type is not recognised.
        """
        if isinstance(indices, int):
            indices = (indices,)

        num_ensemble_dims = len(self.measurements.ensemble_shape)
        explode_axes = self._get_axes_from_axes_types("explode")
        overlay_axes = self._get_axes_from_axes_types("overlay")
        num_indexing_axes = num_ensemble_dims - len(explode_axes) - len(overlay_axes)

        if len(indices) > num_indexing_axes:
            raise ValueError(
                f"too many ensemble indices: got {len(indices)}, expected at most "
                f"{num_indexing_axes}"
            )

        validated_indices = []
        j = 0
        for axes_type in self.axes_types:
            if axes_type in ("explode", "overlay"):
                validated_indices.append(slice(None))
            elif j < len(indices):
                validated_indices.append(indices[j])
                j += 1
            elif axes_type == "index":
                validated_indices.append(0)
            elif axes_type == "range":
                validated_indices.append(slice(None))
            else:
                raise RuntimeError(
                    "axes type must be one of 'index', 'range', 'explode' or 'overlay'"
                )

        return tuple(validated_indices)

    def set_ensemble_indices(self, indices: int | tuple[int, ...] = ()):
        """
        Set the indices into the ensemble dimensions to select the visualized ensemble members. Interactive
        visualization are updated.

        Parameters
        ----------
        indices : int or tuple of int
            Indices for the ensemble axes that are neither "explode" nor
            "overlay".
        """
        self._indices = self._validate_ensemble_indices(indices)
        self.update_artists()
        self.update_panel_labels()

    @abstractmethod
    def update_artists(self):
        pass

    def get_global_vmin_vmax(
        self, vmin: float = None, vmax: float = None
    ) -> tuple[float, float]:
        """
        Return color limits covering all currently displayed measurements.

        Parameters
        ----------
        vmin, vmax : float, optional
            Limits that are given are returned unchanged; missing limits are
            the NaN-ignoring minimum and maximum of the displayed data (of its
            magnitude for complex data).
        """
        measurements = self._get_indexed_measurements()

        if measurements.is_complex:
            measurements = measurements.abs()

        if vmin is None:
            vmin = float(np.nanmin(measurements.array))

        if vmax is None:
            vmax = float(np.nanmax(measurements.array))

        return vmin, vmax

    def set_panel_labels(
        self,
        labels: str = "metadata",
        frameon: bool = True,
        loc: str = "upper left",
        pad: float = 0.1,
        borderpad: float = 0.1,
        prop: dict = None,
        formatting: str = ".3g",
        units: str = None,
        **kwargs,
    ):
        """
        Add a text label to every panel.

        Parameters
        ----------
        labels : str or list of str
            "alphabetic" labels the panels "(a)", "(b)", ...; "metadata"
            derives the labels from the ensemble axes metadata; otherwise the
            given labels are used in panel order.
        frameon, loc, pad, borderpad, prop
            Passed to ``AnchoredText``.
        formatting : str
            Format specification for labels derived from metadata.
        units : str, optional
            Units for labels derived from metadata.
        **kwargs
            Passed both to ``format_title`` (metadata labels) and to
            ``AnchoredText``.

        Raises
        ------
        ValueError
            If ``labels`` is neither a tuple nor a list and its length differs
            from the number of panels.
        """
        labels_type = labels

        if labels == "alphabetic":
            labels = [f"({label})" for label in string.ascii_lowercase]
            if config.get("visualize.use_tex", False):
                labels = [f"${label}$" for label in labels]
        elif labels == "metadata":
            labels = []
            for _, measurement in self._generate_measurements(keepdims=True):
                titles = [
                    axes_metadata.format_title(formatting, units=units, **kwargs)
                    for axes_metadata in measurement.ensemble_axes_metadata
                ]
                labels.append("\n".join(titles))
        elif (
            not isinstance(labels, (tuple, list))
            and len(labels) != np.array(self.axes).size
        ):
            raise ValueError(
                "labels must be 'alphabetic', 'metadata' or provide one label per panel"
            )

        if prop is None:
            prop = {}

        for old_label in self._panel_labels:
            old_label.remove()

        anchored_text = []
        # ``zip`` stops at the shorter sequence, so surplus labels are ignored.
        for ax, text in zip(np.array(self.axes).ravel(), labels):
            at = AnchoredText(
                text,
                pad=pad,
                borderpad=borderpad,
                frameon=frameon,
                loc=loc,
                prop=prop,
                **kwargs,
            )
            # Remembered so that ``update_panel_labels`` can reuse the format.
            at.formatting = formatting
            at.patch.set_boxstyle("round,pad=0.,rounding_size=0.2")
            ax.add_artist(at)
            anchored_text.append(at)

        self._panel_labels = anchored_text

        # Only metadata labels are refreshed when the ensemble indices change.
        if labels_type == "metadata":
            self._metadata_labels = self._panel_labels
        else:
            self._metadata_labels = []

    def update_panel_labels(self):
        """Refresh the text of the metadata-derived panel labels."""
        for anchored_text, (_, measurement) in zip(
            self._metadata_labels, self._generate_measurements(keepdims=True)
        ):
            label = _get_joined_titles(measurement, anchored_text.formatting)
            anchored_text.txt.set_text(label)

    def animate(
        self,
        interval=20,
        blit=True,
        repeat: bool = False,
        adjust_scale: bool = True,
        **kwargs,
    ):
        """
        Create an animation stepping through the first "index" ensemble axis.

        Parameters
        ----------
        interval, blit, repeat
            Passed to ``FuncAnimation``.
        adjust_scale : bool
            If true, ``_update_vmin_vmax`` is called without arguments after
            each frame update.
        **kwargs
            Additional keyword arguments passed to ``FuncAnimation``.

        Returns
        -------
        animation : matplotlib.animation.FuncAnimation

        Raises
        ------
        RuntimeError
            If no ensemble axis has the type "index".
        """

        def update(i):
            self.set_ensemble_indices((i,))
            if adjust_scale:
                self._update_vmin_vmax()
            return self.artists.ravel()

        index_axes = self._get_axes_from_axes_types("index")
        if len(index_axes) == 0:
            raise RuntimeError(
                "animation requires at least one ensemble axis of type 'index'"
            )

        frames = self.measurements.shape[index_axes[0]]

        animation = FuncAnimation(
            self.fig,
            update,
            frames=frames,
            interval=interval,
            blit=blit,
            repeat=repeat,
            **kwargs,
        )
        return animation
[docs]class BaseMeasurementVisualization2D(MeasurementVisualization):
def __init__(
    self,
    measurements: _BaseMeasurement2D | IndexedDiffractionPatterns,
    ax: Axes = None,
    common_scale: bool = False,
    cbar: bool = False,
    explode: bool = None,
    figsize: tuple[float, float] = None,
    interact: bool = False,
):
    """
    Set up the axes grid and the bookkeeping shared by 2D visualizations.

    Parameters
    ----------
    measurements : _BaseMeasurement2D or IndexedDiffractionPatterns
        The measurements to visualize.
    ax : Axes, optional
        Existing axes to draw on; a new figure is created if not given.
    common_scale : bool
        Whether all panels share one color scale (and one set of colorbars).
    cbar : bool
        Whether colorbar axes are reserved.
    explode : bool, optional
        Forwarded to ``_determine_axes_types`` and ``_validate_axes``.
    figsize : tuple of float, optional
        Size of a newly created figure.
    interact : bool
        If true a new figure is created with interactive mode switched off.

    Raises
    ------
    NotImplementedError
        If any ensemble axis resolves to the "overlay" type.
    """
    # measurements = measurements.compute().to_cpu()
    resolved_types = _determine_axes_types(
        measurements=measurements, explode=explode, overlay=None
    )
    if "overlay" in resolved_types:
        raise NotImplementedError

    grid = _validate_axes(
        measurements=measurements,
        ax=ax,
        explode=explode,
        overlay=None,
        cbar=cbar,
        common_color_scale=common_scale,
        figsize=figsize,
        ioff=interact,
    )

    super().__init__(measurements=measurements, axes=grid, axes_types=resolved_types)

    self._xunits = self._yunits = None
    self._scale_units = None
    self._xlabel = self._ylabel = None
    self._column_titles = []
    self._row_titles = []
    self._artists = None
    # The configuration value overrides the flag set by the base class.
    self._autoscale = config.get("visualize.autoscale", False)
    self._common_scale = common_scale
    self._size_bars = []

    # Titles are only useful when there is more than one column or row.
    if self.ncols > 1:
        self.set_column_titles()

    if self.nrows > 1:
        self.set_row_titles()
def _get_vmin_vmax(self):
vmin = np.inf
vmax = -np.inf
for norm in self._normalization.ravel():
vmin = min(vmin, norm.vmin)
vmax = max(vmax, norm.vmax)
return vmin, vmax
def _get_power(self):
    """
    Return the smallest power used by the normalizations of the panels.

    Returns
    -------
    power : float or None
        The minimum over all panels of the ``gamma`` of a ``PowerNorm``, with
        every other normalization counted as 1.0; None if there are no
        normalizations.
    """
    gammas = [
        norm.gamma if isinstance(norm, colors.PowerNorm) else 1.0
        for norm in self._normalization.ravel()
    ]
    if not gammas:
        return None
    return min(gammas)
def set_normalization(
    self,
    power: float = None,
    vmin: float = None,
    vmax: float = None,
):
    """
    Create the color normalization of every panel.

    Parameters
    ----------
    power : float, optional
        Exponent of a power-law normalization. ``None`` and 1.0 both give a
        linear normalization.
    vmin, vmax : float, optional
        Limits of the normalization. Missing limits are taken from the data of
        each panel, or from all displayed data when a common scale is used.
    """
    if self._common_scale:
        vmin, vmax = self.get_global_vmin_vmax(vmin=vmin, vmax=vmax)

    # The default ``power=None`` previously created ``PowerNorm(gamma=None)``,
    # which cannot be used for normalizing; treat it as linear scaling.
    linear = power is None or power == 1.0

    self._normalization = np.zeros(self.axes.shape, dtype=object)
    for i, measurement in self._generate_measurements(keepdims=False):
        if linear:
            norm = colors.Normalize(vmin=vmin, vmax=vmax)
        else:
            norm = colors.PowerNorm(gamma=power, vmin=vmin, vmax=vmax)

        if measurement.is_complex:
            measurement = measurement.abs()

        # Fill in missing limits from the data, ignoring NaNs.
        array = measurement.array
        norm.autoscale_None(array[~np.isnan(array)])
        self._normalization[i] = norm
def _update_vmin_vmax(self, vmin: float = None, vmax: float = None):
for norm, measurement in zip(
self._normalization.ravel(), self._generate_measurements(keepdims=False)
):
norm.vmin = vmin
norm.vmax = vmax
def add_area_indicator(self, area_indicator, panel="first", **kwargs):
    """
    Add an area indicator to the panels without changing the axes limits.

    Parameters
    ----------
    area_indicator
        Object providing ``_add_to_visualization(ax, **kwargs)``.
    panel : {"first", "all"}
        Whether the indicator is added to the first panel only or to every
        panel.
    **kwargs
        Passed to ``_add_to_visualization``.
    """
    first_ax = self.axes[0, 0]
    xlim, ylim = first_ax.get_xlim(), first_ax.get_ylim()

    for position, ax in enumerate(np.array(self.axes).ravel()):
        if panel == "all" or (panel == "first" and position == 0):
            area_indicator._add_to_visualization(ax, **kwargs)

        # Re-apply the limits of the first panel to every panel.
        ax.set_xlim(xlim)
        ax.set_ylim(ylim)
def _update_power(self, power: float = 1.0):
    """
    Change the power of the color normalization of every panel.

    Parameters
    ----------
    power : float
        The new power. A value of 1.0 switches a ``PowerNorm`` back to a linear
        ``Normalize``; any other value installs a ``PowerNorm`` with that
        gamma, keeping the current limits.
    """
    use_power_law = power != 1.0

    for i, _ in self._generate_measurements(keepdims=False):
        artists = self._artists[i]
        current = self._normalization[i]

        if use_power_law and isinstance(current, colors.Normalize):
            replacement = colors.PowerNorm(
                gamma=power, vmin=current.vmin, vmax=current.vmax
            )
            self._normalization[i] = replacement
            artists.norm = replacement

        if (power == 1.0) and isinstance(current, colors.PowerNorm):
            replacement = colors.Normalize(vmin=current.vmin, vmax=current.vmax)
            self._normalization[i] = replacement
            artists.norm = replacement

        # Applies to the normalization currently stored for the panel, which
        # may be the one installed just above.
        if use_power_law and isinstance(current, colors.PowerNorm):
            self._normalization[i].gamma = power
@property
def artists(self):
    """The artists stored by the visualization (``None`` until they are set)."""
    return self._artists
@abstractmethod
def set_artists(self):
    """Create the artists of the visualization; must be implemented by subclasses."""
    pass
def set_scale_units(self, units: str = None):
    """
    Set the units shown for the color scale.

    Parameters
    ----------
    units : str, optional
        The units; if not given, the ``"units"`` entry of the measurement
        metadata is used, or an empty string if there is none.
    """
    self._scale_units = (
        self.measurements.metadata.get("units", "") if units is None else units
    )
def set_cbar_labels(self, label: str = None, **kwargs):
    """
    Set the label and the tick formatting of all colorbars.

    Parameters
    ----------
    label : str, optional
        The colorbar label. If not given, the ``"label"`` entry of the
        measurement metadata is used, followed by the scale units in square
        brackets when scale units are set.
    **kwargs
        Passed to ``set_label`` of each colorbar.
    """
    if label is None:
        label = self.measurements.metadata.get("label", "")
        # TODO: make units work more generally
        # NOTE(review): the units are assumed to be appended to the default
        # label only - confirm that explicit labels should be left untouched.
        has_units = not (self._scale_units is None or len(self._scale_units) == 0)
        label = f"{label} [{self._scale_units}]" if has_units else f"{label}"

    for panel_cbars in self._cbars.values():
        for cbar in panel_cbars:
            cbar.set_label(label, **kwargs)
            formatter = cbar.formatter
            formatter.set_powerlimits((-3, 3))
            formatter.set_useMathText(True)
            cbar.ax.yaxis.set_offset_position("left")
def set_cbar_padding(self, padding: tuple[float, float] = (0.1, 0.1)):
    """
    Set the padding around the colorbars by delegating to the axes grid.

    Parameters
    ----------
    padding : tuple of float
        Forwarded unchanged to ``set_cbar_padding`` of the axes grid.
    """
    self._axes.set_cbar_padding(padding)
def set_cbar_size(self, fraction: float):
    """
    Set the size of the colorbars by delegating to the axes grid.

    Parameters
    ----------
    fraction : float
        Forwarded unchanged to ``set_cbar_size`` of the axes grid.
    """
    self._axes.set_cbar_size(fraction)
def set_cbar_spacing(self, spacing: float):
    """
    Set the spacing between colorbars by delegating to the axes grid.

    Parameters
    ----------
    spacing : float
        Forwarded unchanged to ``set_cbar_spacing`` of the axes grid.
    """
    self._axes.set_cbar_spacing(spacing)
def set_cbars(self, **kwargs):
    """
    Create the colorbars for the panels.

    With an :class:`AxesGrid`, colorbars are drawn into the colorbar axes
    reserved for a panel, and panels without reserved colorbar axes are
    skipped. Otherwise the colorbars are attached to the panel axes.

    Parameters
    ----------
    **kwargs
        Passed to ``plt.colorbar``.
    """
    uses_grid = isinstance(self.axes, AxesGrid)
    cbars = defaultdict(list)

    for i, _ in self._generate_measurements():
        ax = self.axes[i]
        images = self._artists[i]
        # A panel may hold an array of images or a single image.
        has_many = isinstance(images, np.ndarray)

        if not uses_grid:
            if has_many:
                for image in images:
                    cbars[ax].append(plt.colorbar(image, ax=ax, **kwargs))
            else:
                cbars[ax].append(plt.colorbar(images, ax=ax, **kwargs))
            continue

        if ax not in self.axes._caxes:
            continue
        cax = self.axes._caxes[ax]

        if has_many:
            for j, image in enumerate(images):
                cbars[ax].append(plt.colorbar(image, cax=cax[j], **kwargs))
        else:
            cbars[ax].append(plt.colorbar(images, cax=cax[0], **kwargs))

    self._cbars = cbars
def set_scalebars(
    self,
    panel_loc: tuple[int, ...] = ((-1, 0),),
    label: str = "",
    size: float = None,
    loc: str = "lower right",
    borderpad: float = 0.5,
    formatting: str = ".3f",
    size_vertical: float = None,
    sep: float = 6,
    pad: float = 0.3,
    label_top: bool = True,
    frameon: bool = False,
    **kwargs,
):
    """
    Add a scale bar to one or more panels, replacing any existing scale bars.

    Parameters
    ----------
    panel_loc : str or tuple of index tuples
        Panels receiving a scale bar. Either "all", one of "upper left",
        "upper right", "lower left", "lower right", or explicit indices into
        ``self.axes`` (a single index tuple or a tuple of index tuples). Any other
        string falls back to the panel at index ``(0, 0)``.
    label : str or None
        Label of the scale bar. If None, the label is generated from the bar size
        and the current x-units.
    size : float, optional
        Length of the bar in the default units of the x-axis. Defaults to a third
        of the extent along the second-to-last base axis.
    loc : str
        Location of the bar inside each panel.
    borderpad : float
        Border padding of the bar.
    formatting : str
        Format specification used for the size in the generated label.
    size_vertical : float, optional
        Thickness of the bar. Defaults to a twentieth of the extent along the last
        base axis.
    sep : float
        Separation between the bar and its label.
    pad : float
        Padding around the bar and its label.
    label_top : bool
        If True, the label is placed above the bar.
    frameon : bool
        If True, draw a frame around the scale bar.
    kwargs :
        Passed on to ``AnchoredSizeBar``.
    """
    named_panels = {
        "upper left": ((0, -1),),
        "upper right": ((-1, -1),),
        "lower left": ((0, 0),),
        "lower right": ((-1, 0),),
    }

    if isinstance(panel_loc, str):
        if panel_loc == "all":
            panel_loc = tuple(np.ndindex(self.axes.shape))  # noqa
        else:
            # Unrecognized names keep falling back to the first panel.
            panel_loc = named_panels.get(panel_loc, ((0, 0),))
    else:
        # Explicit indices (including the default) must be honored; they used to be
        # silently replaced by ((0, 0),).
        panel_loc = tuple(panel_loc)
        if panel_loc and all(
            isinstance(index, (int, np.integer)) for index in panel_loc
        ):
            # A single index such as (0, 1) refers to one panel.
            panel_loc = (panel_loc,)

    conversion = _get_conversion_factor(
        self._xunits, self.measurements.axes_metadata[-2].units
    )

    if size is None:
        size = (
            self.measurements.base_axes_metadata[-2].sampling
            * self.measurements.base_shape[-2]
            / 3
        )

    if size_vertical is None:
        size_vertical = (
            self.measurements.base_axes_metadata[-1].sampling
            * self.measurements.base_shape[-1]
            / 20
        )

    # Both sizes are converted with the x-axis factor.
    size = conversion * size
    size_vertical = conversion * size_vertical

    if label is None:
        label = f"{size:>{formatting}} {self._xunits}"

    # Drop the scale bars of a previous call before adding new ones.
    for size_bar in self._size_bars:
        size_bar.remove()
    self._size_bars = []

    for index in panel_loc:
        ax = self.axes[index]
        anchored_size_bar = AnchoredSizeBar(
            ax.transData,
            label=label,
            label_top=label_top,
            size=size,
            borderpad=borderpad,
            loc=loc,
            size_vertical=size_vertical,
            sep=sep,
            pad=pad,
            frameon=frameon,
            **kwargs,
        )
        ax.add_artist(anchored_size_bar)
        self._size_bars.append(anchored_size_bar)
def axis_off(self, spines: bool = True):
    """
    Remove axis labels and ticks from all panels.

    Parameters
    ----------
    spines : bool
        If False, the four spines of each panel are hidden as well.
    """
    hide_spines = not spines
    for panel in np.array(self.axes).ravel():
        panel.set_xlabel("")
        panel.set_ylabel("")
        panel.set_xticks([])
        panel.set_yticks([])

        if hide_spines:
            for side in ("top", "right", "bottom", "left"):
                panel.spines[side].set_visible(False)
def adjust_tight_bbox(self):
    """
    Make the figure square, keeping its width, and remove the subplot margins.
    """
    # The aspect ratio is fixed to one; deriving it from the plotted extents
    # is currently disabled.
    aspect = 1

    width = self.fig.get_size_inches()[0]
    self.fig.set_size_inches((width, width * aspect))
    self.fig.subplots_adjust(left=0, bottom=0, right=1.0, top=1)
class MeasurementVisualization2D(BaseMeasurementVisualization2D):
    """
    Show the image(s) using matplotlib.

    Real-valued measurements are shown as one image per panel. Complex-valued
    measurements are shown with domain coloring: the phase is drawn with a cyclic
    colormap whose opacity follows the normalized magnitude, on top of a grayscale
    image of the magnitude.

    Parameters
    ----------
    measurements : _BaseMeasurement2D
        The measurement(s) to show.
    ax : matplotlib.axes.Axes, optional
        If given the plots are added to the axis. This is not available for image grids.
    cbar : bool, optional
        Add colorbar(s) to the image(s). Default is False.
    cmap : str, optional
        Matplotlib colormap name used to map scalar data to colors. For complex
        measurements this is the colormap of the phase. If not given, the default is
        read from the configuration.
    vmin : float, optional
        Minimum of the intensity color scale. Default is the minimum of the array values.
    vmax : float, optional
        Maximum of the intensity color scale. Default is the maximum of the array values.
    power : float
        Show image on a power scale.
    common_scale : bool, optional
        If True all images in an image grid are shown on the same colorscale, and a single colorbar is created (if
        it is requested). Default is False.
    explode : bool, optional
        Passed on unchanged to the base class. Default is False.
    figsize : two float, optional
        Passed on unchanged to the base class.
    interact : bool, optional
        Passed on unchanged to the base class. Default is False.
    """

    def __init__(
        self,
        measurements: _BaseMeasurement2D,
        ax: Axes = None,
        cbar: bool = False,
        cmap: str = None,
        vmin: float = None,
        vmax: float = None,
        power: float = 1.0,
        common_scale: bool = False,
        explode: bool = False,
        figsize: tuple[float, float] = None,
        interact: bool = False,
    ):
        super().__init__(
            measurements,
            ax=ax,
            cbar=cbar,
            common_scale=common_scale,
            explode=explode,
            figsize=figsize,
            interact=interact,
        )

        if cmap is None and measurements.is_complex:
            cmap = config.get("visualize.phase_cmap", "hsluv")
        elif cmap is None:
            cmap = config.get("visualize.cmap", "viridis")

        self._normalization = None
        self._cmap = cmap
        self.set_normalization(power=power, vmin=vmin, vmax=vmax)
        self.set_artists()

        if cbar:
            self.set_cbars()
            self.set_scale_units()
            self.set_cbar_labels()

        self.set_extent()
        self.set_xunits()
        self.set_yunits()
        self.set_xlabels()
        self.set_ylabels()
        self.set_column_titles()

    @property
    def _domain_coloring(self):
        """True if the measurements are complex and hence shown with domain coloring."""
        return self.measurements.is_complex

    def set_cbar_labels(self, label: str = None, **kwargs):
        """
        Label the colorbars.

        With domain coloring every panel has a phase ("arg") and a magnitude ("abs")
        colorbar with fixed labels, and the arguments are ignored. Otherwise the
        call is delegated to the base class.
        """
        if self._domain_coloring:
            for cbar1, cbar2 in self._cbars.values():
                cbar1.set_label("arg", rotation=0, ha="center", va="top")
                cbar1.ax.yaxis.set_label_coords(0.5, -0.02)
                cbar1.set_ticks([-np.pi, -np.pi / 2, 0, +np.pi / 2, np.pi])
                cbar1.set_ticklabels(
                    [
                        r"$-\pi$",
                        r"$-\dfrac{\pi}{2}$",
                        "$0$",
                        r"$\dfrac{\pi}{2}$",
                        r"$\pi$",
                    ]
                )
                cbar2.set_label("abs", rotation=0, ha="center", va="top")
                cbar2.ax.yaxis.set_label_coords(0.5, -0.02)
                cbar2.formatter.set_powerlimits((0, 0))
                cbar2.formatter.set_useMathText(True)
                cbar2.ax.yaxis.set_offset_position("left")
        else:
            super().set_cbar_labels(label, **kwargs)

    def _get_default_xlabel(self, units: str = None):
        return self.measurements.axes_metadata[-2].format_label(units=units)

    def _get_default_ylabel(self, units: str = None):
        return self.measurements.axes_metadata[-1].format_label(units=units)

    def _get_default_xunits(self):
        return self.measurements.axes_metadata[-2].units

    def _get_default_yunits(self):
        return self.measurements.axes_metadata[-1].units

    def set_xlim(self):
        """Reset the image extent; for images the limits follow from the extent."""
        self.set_extent()

    def set_ylim(self):
        """Reset the image extent; for images the limits follow from the extent."""
        self.set_extent()

    def set_extent(self, extent=None):
        """
        Set the extent of all images.

        Parameters
        ----------
        extent : sequence of float, optional
            Extent given as x-limits followed by y-limits. If not given, it is
            computed from the measurements in the current x- and y-units.
        """
        if extent is None:
            x_extent = self.measurements._plot_extent_x(self._xunits)
            y_extent = self.measurements._plot_extent_y(self._yunits)
            extent = x_extent + y_extent

        for image in self._artists.ravel():
            image.set_extent(extent)

    def _add_domain_coloring_imshow(self, ax, array, norm):
        """Draw a complex array as a phase image (returned first) over a magnitude image."""
        abs_array = np.abs(array)
        # The normalized magnitude is used as the opacity of the phase image.
        alpha = np.clip(norm(abs_array), a_min=0.0, a_max=1.0)

        if self._cmap is None:
            # Same configuration key as in __init__.
            cmap = config.get("visualize.phase_cmap", "hsluv")
        else:
            cmap = self._cmap

        if cmap == "hsluv":
            cmap = hsluv_cmap

        im1 = ax.imshow(
            np.angle(array).T,
            origin="lower",
            interpolation="none",
            alpha=alpha.T,
            vmin=-np.pi,
            vmax=np.pi,
            cmap=cmap,
        )
        # The magnitude image is drawn below the phase image.
        im2 = ax.imshow(
            abs_array.T,
            origin="lower",
            interpolation="none",
            cmap="gray",
            zorder=-1,
        )
        return im1, im2

    def _add_real_imshow(self, ax, array):
        """Draw a real array as a single image."""
        im = ax.imshow(
            array.T,
            origin="lower",
            interpolation="none",
            cmap=self._cmap,
        )
        return im

    def set_artists(
        self,
    ):
        """
        Create the image(s) of every panel and store them in ``self._artists``.

        Complex measurements yield two images per panel (phase, magnitude), so the
        artists array gets a trailing axis of length two; for real measurements that
        axis is squeezed out.
        """
        if self.measurements.is_complex:
            artists_per_axes = 2
        else:
            artists_per_axes = 1

        images = np.zeros(self.axes.shape + (artists_per_axes,), dtype=object)

        for i, measurement in self._generate_measurements(keepdims=False):
            ax = self.axes[i]
            norm = self._normalization[i]

            if self._domain_coloring:
                images[i] = self._add_domain_coloring_imshow(
                    ax, measurement.array, norm
                )
                # Only the magnitude image carries the normalization.
                images[i][1].set_norm(norm)
            else:
                images[i] = self._add_real_imshow(ax, measurement.array)
                images[i][0].set_norm(norm)

        if images.shape[-1] == 1:
            images = np.squeeze(images, -1)

        self._artists = images

    def _update_domain_coloring_alpha(self, values, image, normalization):
        """Set the opacity of ``image`` to the normalized ``values`` clipped to [0, 1]."""
        alpha = normalization(values)
        alpha = np.clip(alpha, a_min=0, a_max=1)
        image.set_alpha(alpha)

    def _update_vmin_vmax(self, vmin: float = None, vmax: float = None):
        super()._update_vmin_vmax(vmin=vmin, vmax=vmax)

        if self._domain_coloring:
            # The opacity of the phase images depends on the normalization.
            for i, measurement in self._generate_measurements(keepdims=False):
                images = self._artists[i]
                abs_array = np.abs(measurement.array).T
                self._update_domain_coloring_alpha(
                    abs_array, images[0], self._normalization[i]
                )

    def update_artists(self):
        """Update the data of the existing images from the current measurements."""
        for i, measurement in self._generate_measurements(keepdims=False):
            images = self._artists[i]
            array = measurement.array.T

            if self._domain_coloring:
                abs_array = np.abs(array)
                self._update_domain_coloring_alpha(
                    abs_array, images[0], self._normalization[i]
                )
                images[0].set_data(np.angle(array))
                images[1].set_data(abs_array)
            else:
                images.set_data(array)

    @property
    def widgets(self):
        """
        Build the interactive widget box with the controls next to the figure canvas.

        Raises
        ------
        RuntimeError
            If ipywidgets is not installed.
        """
        if widgets is None:
            raise ipywidgets_not_installed

        canvas = self.fig.canvas

        sliders = make_sliders_from_ensemble_axes(
            self,
            self.axes_types,
        )
        power_scale_button = _make_power_scale_slider(self)
        scale_button = _make_scale_button(self)
        autoscale_button = _make_autoscale_button(self)
        continuous_update_button = _make_continuous_button(sliders)

        scale_button.layout = widgets.Layout(width="20%")
        autoscale_button.layout = widgets.Layout(width="30%")
        continuous_update_button.layout = widgets.Layout(width="50%")

        scale_box = widgets.VBox(
            [widgets.HBox([scale_button, autoscale_button, continuous_update_button])]
        )
        scale_box.layout = widgets.Layout(width="300px")

        gui = widgets.VBox(
            [
                widgets.VBox(sliders),
                scale_box,
                power_scale_button,
            ]
        )
        return widgets.HBox([gui, canvas])
class MeasurementVisualization1D(MeasurementVisualization):
    """
    Show 1D measurement(s) as line plots using matplotlib.

    Parameters
    ----------
    measurements : _BaseMeasurement1D
        The measurement(s) to show.
    ax : matplotlib.axes.Axes, optional
        If given the plots are added to the axis.
    common_scale : bool, optional
        If True, the panels share their y-axis and the default y-limits are computed
        from all measurements. Default is True.
    explode : sequence of str or bool, optional
        Used together with `overlay` to determine the type of each ensemble axis.
    overlay : sequence of str or bool, optional
        Used together with `explode` to determine the type of each ensemble axis.
    figsize : two float, optional
        Passed on to the creation of the axes.
    interact : bool, optional
        Passed on to the creation of the axes. Default is False.
    kwargs :
        Passed on to `matplotlib.axes.Axes.plot` for every line.
    """

    def __init__(
        self,
        measurements: _BaseMeasurement1D,
        ax: Axes = None,
        common_scale: bool = True,
        explode: Sequence[str] | bool = False,
        overlay: Sequence[str] | bool = False,
        figsize: tuple[float, float] = None,
        interact: bool = False,
        **kwargs
    ):
        axes_types = _determine_axes_types(
            measurements, explode=explode, overlay=overlay
        )
        axes = _validate_axes(
            measurements=measurements,
            ax=ax,
            explode=explode,
            overlay=overlay,
            cbar=False,
            common_color_scale=False,
            figsize=figsize,
            aspect=False,
            ioff=interact,
            sharey=common_scale,
        )

        super().__init__(measurements=measurements, axes=axes, axes_types=axes_types)

        self._xunits = None
        self._yunits = None
        self._xlabel = None
        self._ylabel = None
        self._column_titles = []
        self._lines = np.array([[]])
        self._common_scale = common_scale

        self.set_artists(**kwargs)
        self.set_xunits()
        self.set_yunits()
        self._autoscale = config.get("visualize.autoscale", False)

        if self.ncols > 1:
            self.set_column_titles()

    @property
    def artists(self):
        """Array holding, for every panel, the list of its lines."""
        return self._artists

    def _get_default_xlabel(self, units: str = None):
        return self.measurements.axes_metadata[-1].format_label(units)

    def _get_default_ylabel(self, units: str = None):
        axes = LinearAxis(label=self.measurements.metadata.get("label", ""))
        return format_label(axes, units)

    def _get_default_xunits(self):
        return self.measurements.axes_metadata[-1].units

    def _get_default_yunits(self):
        return self.measurements.metadata.get("units", "")

    def set_xlim(self, xlim=None):
        """
        Set the x-axis view limits of all panels and refresh the x-data of the lines.

        Parameters
        ----------
        xlim : two float, optional
            The limits. Default is the plot extent with a margin of 5 % on each side.
        """
        extent = self.measurements._plot_extent(self._xunits)
        margin = (extent[1] - extent[0]) * 0.05

        if xlim is None:
            # The lower limit is the start of the extent minus the margin; negating
            # the start was only correct for an extent starting at zero.
            xlim = [extent[0] - margin, extent[1] + margin]

        # The x-data is identical for all lines.
        x = self._get_xdata()
        for i, _ in self._generate_measurements():
            self.axes[i].set_xlim(xlim)
            for artist in self.artists[i]:
                artist.set_xdata(x)

    def set_ylim(self, ylim=None):
        """
        Set the y-axis view limits of all panels.

        Parameters
        ----------
        ylim : two float, optional
            The limits. By default they span the values with a margin of 5 %, taken
            over all measurements if a common scale is used and per panel otherwise.
        """

        def _get_extent(measurements):
            min_value = measurements.min()
            max_value = measurements.max()
            margin = (max_value - min_value) * 0.05
            return [min_value - margin, max_value + margin]

        if self._common_scale and ylim is None:
            common_ylim = _get_extent(self.measurements)
        else:
            common_ylim = ylim

        for i, measurement in self._generate_measurements():
            if common_ylim is None:
                panel_ylim = _get_extent(measurement)
            else:
                panel_ylim = common_ylim
            self.axes[i].set_ylim(panel_ylim)

    def set_legends(self, loc: str = "first", **kwargs):
        """
        Add legends to the panels.

        Parameters
        ----------
        loc : str
            Panels receiving a legend: "first", "last" or "all".
        kwargs :
            Passed on to `matplotlib.axes.Axes.legend`.
        """
        indices = [index for index in np.ndindex(*self.axes.shape)]

        if loc == "first":
            loc = indices[:1]
        elif loc == "last":
            loc = indices[-1:]
        elif loc == "all":
            loc = indices

        for i, _ in self._generate_measurements():
            if i in loc:
                self.axes[i].legend(**kwargs)

    def _update_vmin_vmax(self, vmin: float = None, vmax: float = None):
        self.set_ylim([vmin, vmax])

    def set_artists(self, **kwargs):
        """
        Plot one line per ensemble member in every panel and store the lines.

        Parameters
        ----------
        kwargs :
            Passed on to `matplotlib.axes.Axes.plot`. Unless a "label" is given,
            every line is labelled with the titles of its ensemble axes.
        """
        artists = np.zeros(self.axes.shape, dtype=object)
        # The x-data is identical for all lines.
        x = self._get_xdata()

        for i, measurement in self._generate_measurements(keepdims=False):
            ax = self.axes[i]
            new_lines = []
            for _, line_profile in measurement.generate_ensemble(keepdims=True):
                # Work on a copy: writing the generated label into the shared kwargs
                # made every later line reuse the label of the first one.
                line_kwargs = dict(kwargs)
                if "label" not in line_kwargs:
                    line_kwargs["label"] = "-".join(
                        axis.format_title(".3f")
                        for axis in line_profile.ensemble_axes_metadata
                    )

                # Drop the leading ensemble axes of length one kept by keepdims.
                y = line_profile.array[(0,) * (len(line_profile.shape) - 1)]
                new_lines.append(ax.plot(x, y, **line_kwargs)[0])

            # Plain item assignment; ndarray.itemset was removed in NumPy 2.0.
            artists[i] = new_lines

        self._artists = artists

    def _get_xdata(self):
        """Return the x-coordinates of the samples in the current x-units."""
        extent = self.measurements._plot_extent(self._xunits)
        return np.linspace(
            extent[0],
            extent[1],
            self.measurements.shape[-1],
            endpoint=False,
        )

    def update_artists(self):
        """Update the data of the existing lines from the current measurements."""
        x = self._get_xdata()
        for i, measurements in self._generate_measurements(keepdims=False):
            lines = self._artists[i]
            for line, measurement in zip(lines, measurements):
                line.set_data(x, measurement.array)

    @property
    def widgets(self):
        """
        Build the interactive widget box with the controls next to the figure canvas.

        Raises
        ------
        RuntimeError
            If ipywidgets is not installed.
        """
        if widgets is None:
            raise ipywidgets_not_installed

        canvas = self.fig.canvas

        sliders = make_sliders_from_ensemble_axes(
            self,
            self.axes_types,
        )
        scale_button = _make_scale_button(self)
        autoscale_button = _make_autoscale_button(self)
        continuous_update_button = _make_continuous_button(sliders)

        scale_button.layout = widgets.Layout(width="20%")
        autoscale_button.layout = widgets.Layout(width="30%")
        continuous_update_button.layout = widgets.Layout(width="50%")

        scale_box = widgets.VBox(
            [widgets.HBox([scale_button, autoscale_button, continuous_update_button])]
        )
        scale_box.layout = widgets.Layout(width="300px")

        gui = widgets.VBox(
            [
                widgets.VBox(sliders),
                scale_box,
            ]
        )
        return widgets.HBox([gui, canvas])
class DiffractionSpotsVisualization(BaseMeasurementVisualization2D):
    """
    Display a diffraction pattern as indexed Bragg reflections.

    Every reflection is drawn as a circle whose color and size follow its
    normalized intensity.

    Parameters
    ----------
    measurements : IndexedDiffractionPattern
        Diffraction pattern to be displayed.
    ax : matplotlib.axes.Axes, optional
        If given the plots are added to the axis.
    cbar : bool, optional
        Add colorbar(s). Default is False.
    cmap : str, optional
        Matplotlib colormap name. A name that is not a registered colormap is used
        as a single color for all spots. If not given, the default is read from the
        configuration.
    vmin : float, optional
        Minimum of the intensity color scale.
    vmax : float, optional
        Maximum of the intensity color scale.
    power : float
        Show the intensities on a power scale.
    scale : float
        Size of the circles representing the diffraction spots.
    common_scale : bool, optional
        Passed on unchanged to the base class. Default is False.
    explode : bool, optional
        Passed on unchanged to the base class. Default is False.
    figsize : two float, optional
        Passed on unchanged to the base class.
    interact : bool, optional
        Passed on unchanged to the base class. Default is False.
    """

    def __init__(
        self,
        measurements: IndexedDiffractionPatterns,
        ax: Axes = None,
        cbar: bool = False,
        cmap: str = None,
        vmin: float = None,
        vmax: float = None,
        power: float = 1.0,
        scale: float = 0.1,
        common_scale: bool = False,
        explode: bool = False,
        figsize: tuple[float, float] = None,
        interact: bool = False,
    ):
        measurements = measurements.sort(criterion="distance")

        super().__init__(
            measurements,
            ax=ax,
            cbar=cbar,
            common_scale=common_scale,
            explode=explode,
            figsize=figsize,
            interact=interact,
        )

        self._scale = scale

        if cmap is None:
            cmap = config.get("visualize.cmap", "viridis")

        self._normalization = None
        self._scale_normalization = None
        # Matches the default threshold of `set_miller_index_annotations`, which is
        # the value used when annotations are created implicitly.
        self._annotation_threshold = 1.0
        self._cmap = cmap
        self._size_bars = []
        self._miller_index_annotations = None

        self.set_normalization(power=power, vmin=vmin, vmax=vmax)
        self.set_artists()

        if cbar:
            self.set_cbars()
            self.set_scale_units()
            self.set_cbar_labels()

        self.set_xunits()
        self.set_yunits()
        self.set_xlabels()
        self.set_ylabels()
        self.set_xlim()
        self.set_ylim()

    def _get_scales(self, indexed_diffraction_spots, norm):
        """Return the spot diameters: square root of the normalized intensities, scaled."""
        conversion = _get_conversion_factor(self._xunits, self._get_default_xunits())
        return (
            norm(indexed_diffraction_spots.intensities) ** 0.5
            * self._scale
            * 0.5
            * conversion
        )

    def _get_positions(self, indexed_diffraction_spots):
        """Return a copy of the in-plane spot positions in the current x- and y-units."""
        positions = indexed_diffraction_spots.positions[:, :2].copy()
        positions[:, 0] *= _get_conversion_factor(
            self._xunits, self._get_default_xunits()
        )
        positions[:, 1] *= _get_conversion_factor(
            self._yunits, self._get_default_yunits()
        )
        return positions

    def _update_scales(self):
        """Recompute the spot sizes from the current normalization."""
        for i, measurement in self._generate_measurements(keepdims=False):
            artists = self._artists[i]
            norm = self._normalization[i]
            scales = self._get_scales(measurement, norm)
            # NOTE(review): this writes matplotlib-private attributes; confirm that
            # the values are on the same footing as the widths/heights passed to
            # the EllipseCollection constructor in `set_artists`.
            artists._widths = scales
            artists._heights = scales
            artists.set()

    def _update_vmin_vmax(self, vmin: float = None, vmax: float = None):
        super()._update_vmin_vmax(vmin, vmax)
        self._update_scales()

    def _update_power(self, power: float = 1.0):
        super()._update_power(power)
        self._update_scales()

    def set_artists(self):
        """Remove any existing spot collections and draw a new one in every panel."""
        if self._artists is not None:
            for artist in self.artists.ravel():
                artist.remove()

        self._artists = np.zeros(self.axes.shape, dtype=object)

        # A name that is not a registered colormap is interpreted as a single color.
        if self._cmap not in plt.colormaps():
            cmap = ListedColormap([self._cmap])
        else:
            cmap = self._cmap

        for i, measurement in self._generate_measurements(keepdims=False):
            ax = self.axes[i]
            norm = self._normalization[i]
            scales = self._get_scales(measurement, norm)
            positions = self._get_positions(measurement)

            ellipse_collection = EllipseCollection(
                widths=scales,
                heights=scales,
                angles=0.0,
                units="xy",
                array=measurement.intensities,
                cmap=cmap,
                offsets=positions,
                transOffset=ax.transData,
            )
            ellipse_collection.set_norm(norm)
            ax.add_collection(ellipse_collection)
            self._artists[i] = ellipse_collection
            ax.axis("equal")

    @property
    def _reciprocal_space_axes(self):
        return [
            ReciprocalSpaceAxis(
                label="kx", sampling=1.0, units="1/Å", _tex_label="$k_x$"
            ),
            ReciprocalSpaceAxis(
                label="ky", sampling=1.0, units="1/Å", _tex_label="$k_y$"
            ),
        ]

    def _get_default_xlabel(self, units: str = None):
        return self._reciprocal_space_axes[-2].format_label(units)

    def _get_default_ylabel(self, units: str = None):
        return self._reciprocal_space_axes[-1].format_label(units)

    def _get_default_xunits(self):
        # Index the x-axis, consistent with `_get_default_xlabel`.
        return self._reciprocal_space_axes[-2].units

    def _get_default_yunits(self):
        return self._reciprocal_space_axes[-1].units

    def set_xlim(self, xlim: tuple[float, float] = None):
        """
        Set the x-axis view limits.

        Parameters
        ----------
        xlim : two float, optional
            Limits applied to all panels. By default every panel is limited
            symmetrically to 1.2 times its largest absolute x-coordinate.
        """
        for i, measurement in self._generate_measurements():
            if xlim is None:
                conversion = _get_conversion_factor(
                    self._xunits, self._get_default_xunits()
                )
                bound = conversion * (np.abs(measurement.positions[:, 0]).max() * 1.2)
                panel_xlim = [-bound, bound]
            else:
                panel_xlim = xlim

            self.axes[i].set_xlim(panel_xlim)

    def set_ylim(self, ylim: tuple[float, float] = None):
        """
        Set the y-axis view limits.

        Parameters
        ----------
        ylim : two float, optional
            Limits applied to all panels. By default every panel is limited
            symmetrically to 1.2 times its largest absolute y-coordinate.
        """
        for i, measurement in self._generate_measurements():
            if ylim is None:
                # Convert with the y-units, not the x-units.
                conversion = _get_conversion_factor(
                    self._yunits, self._get_default_yunits()
                )
                bound = conversion * (np.abs(measurement.positions[:, 1]).max() * 1.2)
                panel_ylim = [-bound, bound]
            else:
                panel_ylim = ylim

            self.axes[i].set_ylim(panel_ylim)

    def set_xunits(self, units: str = None):
        """Set the x-units and redraw the spots at the converted positions."""
        super().set_xunits(units)
        self.set_artists()

    def set_yunits(self, units: str = None):
        """Set the y-units and redraw the spots at the converted positions."""
        super().set_yunits(units)
        self.set_artists()

    def update_artists(self):
        """Update sizes, colors and annotation visibility from the current measurements."""
        for i, measurement in self._generate_measurements(keepdims=False):
            artists = self._artists[i]
            norm = self._normalization[i]
            scales = self._get_scales(measurement, norm)
            artists._widths = np.clip(scales, a_min=1e-3, a_max=1e3)
            artists._heights = np.clip(scales, a_min=1e-3, a_max=1e3)
            artists.set(array=measurement.intensities)

        self._set_hkl_visibility()

    def remove_miller_index_annotations(self):
        """Remove all Miller index annotations; a no-op if none were created."""
        for annotation in self._miller_index_annotations or ():
            annotation.remove()
        self._miller_index_annotations = []

    def set_hkl_threshold(self, threshold):
        """Show the Miller indices only for spots with an intensity above `threshold`."""
        self._annotation_threshold = threshold
        self._set_hkl_visibility()

    def _set_hkl_visibility(self):
        """Apply the current threshold to the annotations, creating them if needed."""
        if self._miller_index_annotations is None:
            # Pass the current threshold along; the default argument would
            # otherwise overwrite a threshold that was just set.
            self.set_miller_index_annotations(threshold=self._annotation_threshold)

        # The annotations of all panels are stored in one flat list, in panel
        # order; each panel therefore owns a consecutive slice of it.
        start = 0
        for _, measurement in self._generate_measurements(keepdims=False):
            visibility = measurement.intensities > self._annotation_threshold
            stop = start + len(visibility)
            panel_annotations = self._miller_index_annotations[start:stop]
            for annotation, visible in zip(panel_annotations, visibility):
                annotation.set_visible(visible)
            start = stop

    def set_miller_index_annotations(
        self,
        threshold: float = 1.0,
        size: int = 8,
        alignment: str = "top",
        **kwargs,
    ):
        """
        Annotate the diffraction spots with their Miller indices.

        Existing annotations are replaced.

        Parameters
        ----------
        threshold : float
            Only annotations of spots with an intensity above this value are visible.
        size : int
            Font size of the annotations.
        alignment : str
            Placement of the text relative to the spot: "top", "center" or "bottom".
        kwargs :
            Passed on to `matplotlib.axes.Axes.annotate`.

        Raises
        ------
        ValueError
            If `alignment` is not one of the supported values.
        """
        if alignment not in ("top", "center", "bottom"):
            raise ValueError(
                f"alignment must be 'top', 'center' or 'bottom', got {alignment!r}"
            )

        # Avoid stacking duplicates when called more than once.
        if self._miller_index_annotations:
            self.remove_miller_index_annotations()

        self._annotation_threshold = threshold
        self._miller_index_annotations = []

        for i, measurement in self._generate_measurements(keepdims=False):
            ax = self.axes[i]
            norm = self._normalization[i]
            visibility = measurement.intensities > threshold
            positions = self._get_positions(measurement)
            scales = self._get_scales(measurement, norm)

            for hkl, position, visible, scale in zip(
                measurement.miller_indices, positions, visibility, scales
            ):
                # Offset the text by the spot radius so it does not cover the spot.
                if alignment == "top":
                    xy = position[:2] + [0, scale / 2]
                    va = "bottom"
                elif alignment == "center":
                    xy = position[:2]
                    va = "center"
                else:
                    xy = position[:2] - [0, scale / 2]
                    va = "top"

                if config.get("visualize.use_tex"):
                    # Negative indices are written with an overbar.
                    text = r" \ ".join(
                        [
                            f"\\bar{{{abs(index)}}}" if index < 0 else f"{index}"
                            for index in hkl
                        ]
                    )
                    text = f"${text}$"
                else:
                    text = "{} {} {}".format(*hkl)

                annotation = ax.annotate(
                    text,
                    xy=xy,
                    ha="center",
                    va=va,
                    size=size,
                    visible=visible,
                    **kwargs,
                )
                self._miller_index_annotations.append(annotation)

    def pick_events(self):
        """Make the spots pickable; picking a spot shows its indices, position and intensity."""
        self._pick_annotations = {}
        for ax, artist in zip(np.array(self.axes).ravel(), self.artists.ravel()):
            artist.set_picker(True)
            annotation = ax.annotate(
                "",
                xy=(0, 0),
                xycoords="data",
                xytext=(20.0, 20.0),
                textcoords="offset points",
                bbox=dict(boxstyle="round", fc="w"),
                arrowprops=dict(arrowstyle="->"),
                visible=False,
            )
            self._pick_annotations[artist] = annotation

        def onpick(event):
            # If several spots were hit, the first one is reported.
            hkl = self.measurements.miller_indices[event.ind][0].tolist()
            position = self.measurements.positions[event.ind][0]
            intensity = event.artist.get_array()[event.ind].item()
            annotation = self._pick_annotations[event.artist]
            annotation.set_text(
                "\n".join(
                    (
                        f"hkl: {' '.join(map(str, hkl))}",
                        f"coordinate: {'{:.2f}, {:.2f}, {:.2f}'.format(*position.tolist())}",
                        f"intensity: {intensity:.4g}",
                    )
                )
            )
            annotation.xy = position[:2]
            annotation.set_visible(True)

        self.fig.canvas.mpl_connect("pick_event", onpick)

    @property
    def widgets(self):
        """
        Build the interactive widget box with the controls next to the figure canvas.

        Raises
        ------
        RuntimeError
            If ipywidgets is not installed.
        """
        if widgets is None:
            raise ipywidgets_not_installed

        canvas = self.fig.canvas
        sliders = make_sliders_from_ensemble_axes(self, self.axes_types)

        def index_update_callback(change):
            if self._autoscale:
                vmin, vmax = self.get_global_vmin_vmax()
                self._update_vmin_vmax(vmin, vmax)

        _set_update_indices_callback(sliders, self, callbacks=(index_update_callback,))

        def hkl_slider_changed(change):
            self.set_hkl_threshold(change["new"])

        hkl_slider = widgets.FloatLogSlider(
            description="Index threshold", min=-10, max=0, value=1, step=1e-6
        )
        hkl_slider.observe(hkl_slider_changed, "value")

        power_scale_slider = _make_power_scale_slider(self)
        scale_button = _make_scale_button(self)
        autoscale_button = _make_autoscale_button(self)
        continuous_update_button = _make_continuous_button(sliders)

        scale_button.layout = widgets.Layout(width="20%")
        autoscale_button.layout = widgets.Layout(width="30%")
        continuous_update_button.layout = widgets.Layout(width="50%")

        scale_box = widgets.VBox(
            [widgets.HBox([scale_button, autoscale_button, continuous_update_button])]
        )
        scale_box.layout = widgets.Layout(width="300px")

        gui = widgets.VBox(
            [
                widgets.VBox(sliders),
                scale_box,
                power_scale_slider,
                hkl_slider,
            ]
        )
        return widgets.HBox([gui, canvas])
# The twelve edges of the unit cube, each given as a pair of corner points in
# fractional coordinates. `show_atoms` multiplies the corners with the cell
# matrix to obtain the outline of the cell.
_cube = np.array(
    [
        [[0, 0, 0], [0, 0, 1]],
        [[0, 0, 0], [0, 1, 0]],
        [[0, 0, 0], [1, 0, 0]],
        [[0, 0, 1], [0, 1, 1]],
        [[0, 0, 1], [1, 0, 1]],
        [[0, 1, 0], [1, 1, 0]],
        [[0, 1, 0], [0, 1, 1]],
        [[1, 0, 0], [1, 1, 0]],
        [[1, 0, 0], [1, 0, 1]],
        [[0, 1, 1], [1, 1, 1]],
        [[1, 0, 1], [1, 1, 1]],
        [[1, 1, 0], [1, 1, 1]],
    ]
)
def _merge_columns(atoms: Atoms, plane, tol: float = 1e-7) -> Atoms:
    """
    Merge, separately for every species, the atoms that coincide in projection.

    Parameters
    ----------
    atoms : ase.Atoms
        The atoms to merge.
    plane : str or two float
        The projection plane, passed on to `_merge_positions`.
    tol : float
        Tolerance passed on to `_merge_positions`.

    Returns
    -------
    ase.Atoms
        New atoms with the cell of the input, containing the merged positions
        grouped by atomic number.
    """
    species, species_labels = np.unique(atoms.numbers, return_inverse=True)

    merged = Atoms(cell=atoms.cell)
    for number, members in zip(species, label_to_index(species_labels)):
        kept = _merge_positions(atoms.positions[members], plane, tol)
        merged += Atoms(positions=kept, numbers=np.full((len(kept),), number))

    return merged
def _merge_positions(positions, plane, tol: float = 1e-7) -> np.ndarray:
    """
    Reduce positions that coincide in projection to a single position each.

    Positions are grouped by their in-plane coordinates rounded to multiples of
    `tol`; from every group the position with the smallest coordinate along the
    third axis is kept.

    Parameters
    ----------
    positions : np.ndarray
        Positions as an array with three columns.
    plane : str or two float
        The projection plane, converted to axis indices by `plane_to_axes`.
    tol : float
        Rounding step used to decide whether in-plane coordinates coincide.

    Returns
    -------
    np.ndarray
        The kept positions, one row per group.
    """
    axes = plane_to_axes(plane)
    in_plane = positions[:, axes[:2]]
    snapped = tol * np.round(in_plane / tol)
    columns, column_labels = np.unique(snapped, axis=0, return_inverse=True)

    merged = np.zeros((len(columns), 3))
    for row, members in enumerate(label_to_index(column_labels)):
        column = positions[members]
        merged[row] = column[np.argmin(column[:, axes[2]])]

    return merged
def show_atoms(
    atoms: Atoms,
    plane: tuple[float, float] | str = "xy",
    ax: Axes = None,
    scale: float = 0.75,
    title: str = None,
    numbering: bool = False,
    show_periodic: bool = False,
    figsize: tuple[float, float] = None,
    legend: bool = False,
    merge: float = 1e-2,
    tight_limits: bool = False,
    show_cell: bool = None,
    **kwargs,
):
    """
    Display 2D projection of atoms as a matplotlib plot.

    Parameters
    ----------
    atoms : ase.Atoms
        The atoms to be shown.
    plane : str, two float
        The projection plane given as a concatenation of 'x' 'y' and 'z', e.g. 'xy', or as two floats representing the
        azimuth and elevation angles of the viewing direction [degrees], e.g. (45, 45).
    ax : matplotlib.axes.Axes, optional
        If given the plots are added to the axes.
    scale : float
        Factor scaling their covalent radii for the atom display sizes (default is 0.75).
    title : str
        Title of the displayed image. Default is None.
    numbering : bool
        Display the index of the Atoms as a number. Requires `merge` to be disabled. Default is False.
    show_periodic : bool
        If True, show the periodic images of the atoms at the cell boundary.
    figsize : two int, optional
        The figure size given as width and height in inches, passed to `matplotlib.pyplot.figure`.
    legend : bool
        If True, add a legend indicating the color of the atomic species.
    merge : float
        To speed up plotting large numbers of atoms, those closer than the given value [Å] are merged. Merging is
        disabled by a value of zero (or False).
    tight_limits : bool
        If True the limits of the plot are adjusted to the extent of the projected cell.
    show_cell : bool, optional
        If True, the outline of the cell is drawn. By default the cell is shown unless `tight_limits` is True.
    kwargs : Keyword arguments for matplotlib.collections.PatchCollection.

    Returns
    -------
    matplotlib.figure.Figure, matplotlib.axes.Axes

    Raises
    ------
    ValueError
        If `numbering` is requested while `merge` is enabled.
    """
    # Reject the conflicting options before anything is drawn, so that an axes
    # passed in by the caller is left untouched.
    if numbering and merge and len(atoms) > 0:
        raise ValueError("atom numbering requires 'merge' to be False")

    if show_periodic:
        atoms = pad_atoms(atoms.copy(), margins=1e-3)

    if merge > 0.0:
        atoms = _merge_columns(atoms, plane, merge)

    if show_cell is None:
        show_cell = not tight_limits

    if ax is None:
        fig, ax = plt.subplots(figsize=figsize)
    else:
        fig = ax.get_figure()

    cell = atoms.cell
    axes = plane_to_axes(plane)

    # Project the twelve cell edges onto the two in-plane axes.
    cell_lines = np.array(
        [[np.dot(line[0], cell), np.dot(line[1], cell)] for line in _cube]
    )
    cell_lines_x, cell_lines_y = cell_lines[..., axes[0]], cell_lines[..., axes[1]]

    if show_cell:
        for cell_line_x, cell_line_y in zip(cell_lines_x, cell_lines_y):
            ax.plot(cell_line_x, cell_line_y, "k-")

    # Defined up front so that the numbering below also works without atoms.
    positions = np.zeros((0, 2))
    order = np.zeros((0,), dtype=int)

    if len(atoms) > 0:
        # Sort by decreasing coordinate along the viewing axis; later patches are
        # drawn on top of earlier ones.
        order = np.argsort(-atoms.positions[:, axes[2]])
        positions = atoms.positions[:, axes[:2]][order]

        face_colors = jmol_colors[atoms.numbers[order]]
        sizes = covalent_radii[atoms.numbers[order]] * scale

        circles = [Circle(position, size) for position, size in zip(positions, sizes)]
        coll = PatchCollection(
            circles, facecolors=face_colors, edgecolors="black", **kwargs
        )
        ax.add_collection(coll)

    ax.axis("equal")
    if isinstance(plane, str):
        ax.set_xlabel(plane[0] + " [Å]")
        ax.set_ylabel(plane[1] + " [Å]")
    else:
        # A viewing direction given as angles has no named in-plane axes.
        ax.set_xlabel("[Å]")
        ax.set_ylabel("[Å]")
    ax.set_title(title)

    if numbering:
        # Label every atom with its index in `atoms`, not its drawing order.
        for index, position in zip(order, positions):
            ax.annotate(f"{index}", xy=position, ha="center", va="center")

    if legend:
        legend_elements = [
            Line2D(
                [0],
                [0],
                marker="o",
                color="w",
                markeredgecolor="k",
                label=chemical_symbols[unique],
                markerfacecolor=jmol_colors[unique],
                markersize=12,
            )
            for unique in np.unique(atoms.numbers)
        ]
        ax.legend(handles=legend_elements, loc="upper right")

    if tight_limits:
        ax.set_adjustable("box")
        ax.set_xlim([np.min(cell_lines_x), np.max(cell_lines_x)])
        ax.set_ylim([np.min(cell_lines_y), np.max(cell_lines_y)])

    return fig, ax