import ase.build
import dask
from bokeh.io import output_notebook
from IPython.display import Image, display

import abtem

output_notebook()

Parallelization

The computational cost of running multislice simulations can become high depending on the size of the specimen, the number of probe positions, the number of frozen phonon configurations, and many other factors. This cost can be mitigated using parallelism. Many of the computations in abTEM are embarrassingly parallel: for example, every probe position is independent, so each CPU core may calculate a batch of positions independently, only requiring communication after finishing a run of the multislice algorithm.

abTEM is parallelized using Dask [DaskDTeam16], a flexible library for parallel computing in Python. Dask allows scaling from a single laptop to hundreds of nodes at high-performance computing (HPC) facilities with minimal changes to the code.

In this walkthrough, we introduce how abTEM uses Dask. Although this is not required knowledge for running simulations on a single machine, it may help you optimize your simulations. If you are already an experienced Dask user, most of what you know applies directly to abTEM. If you are new to Dask, you may benefit from watching this introduction before continuing.

We note that Dask is also used in several other libraries for the analysis of electron microscopy data, for example, HyperSpy, LiberTEM and py4DSTEM, so knowing this library may benefit you more generally.

Task graphs

Simulating TEM experiments requires executing multiple tasks, each of which may depend on the output of previous tasks. In Dask, this is represented as a task graph: each task is a node, and an edge connects two nodes when one task depends on the output of the other. The simulation result is obtained by executing every task (node) in the graph with a Dask scheduler, either on a single machine or on a cluster.
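To make the idea concrete without abTEM, here is a minimal plain-Dask sketch built with dask.delayed, which turns function calls into graph nodes instead of running them immediately. The functions simulate and average are hypothetical stand-ins, not abTEM functions: four independent "simulation" branches feed into one averaging node. The structure mirrors the frozen phonon example below.

```python
import dask

@dask.delayed
def simulate(config_index):
    # Hypothetical stand-in for one independent multislice run.
    return config_index ** 2

@dask.delayed
def average(results):
    # This node depends on every branch, so the graph merges here.
    return sum(results) / len(results)

# Building the graph: nothing is executed yet.
branches = [simulate(i) for i in range(4)]  # four independent nodes
result = average(branches)                  # one node with four incoming edges

# Executing the graph with the default scheduler.
print(result.compute())  # 3.5
```

With graphviz installed, result.visualize() draws this graph.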

To illustrate this in practice, below we create the task graph for running a plane wave multislice simulation of gold in the \(\left<100\right>\) zone axis with 4 frozen phonon configurations (please refer to previous walkthroughs for those details).

atoms = ase.build.bulk("Au", cubic=True) * (5, 5, 2)

frozen_phonons = abtem.FrozenPhonons(
    atoms, num_configs=4, sigmas=0.1, ensemble_mean=False
)

potential = abtem.Potential(frozen_phonons, gpts=512, slice_thickness=2)

probe = abtem.PlaneWave(energy=200e3)

exit_waves = probe.multislice(potential)

The result is an ensemble of \(4\) wave functions of shape \(512\times512\), which may be represented as a 3D array. The first dimension represents the frozen phonon ensemble, and the last two dimensions represent the 2D wave functions.

As we have not executed the task graph yet, the wave functions are represented as a Dask Array. We can think of the Dask array as being composed of many smaller NumPy arrays, called chunks, and operations may be applied to each chunk rather than the full array. This enables:

  1. Parallelism over the chunks;

  2. Representing a larger-than-memory array as many smaller arrays, each of which fits in memory.
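Both points can be illustrated with a small plain-Dask sketch. It uses a 4 × 64 × 64 array of ones as a stand-in for the ensemble, so it is not the abTEM object. dask.array's map_blocks applies an ordinary NumPy function to each chunk independently, which is what allows the chunks to be processed in parallel.

```python
import numpy as np
import dask.array as da

# A lazy 4 x 64 x 64 complex array split into 4 chunks along the first axis.
waves = da.ones((4, 64, 64), dtype=np.complex64, chunks=(1, 64, 64))

# map_blocks applies a NumPy function to each chunk independently,
# so the four chunks can be processed in parallel.
intensity = waves.map_blocks(lambda block: np.abs(block) ** 2, dtype=np.float32)

print(intensity.numblocks)       # (4, 1, 1)
result = intensity.compute()     # assembles the chunks into a NumPy array
print(type(result), result.shape)
```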

The __repr__ of the Dask array shows how the chunks are laid out.

exit_waves.array
            Array          Chunk
----------  -------------  -------------
Bytes       8.00 MiB       2.00 MiB
Shape       (4, 512, 512)  (1, 512, 512)
Dask graph  4 chunks in 18 graph layers
Data type   complex64 numpy.ndarray

We see that this Dask array has the shape (4, 512, 512), requiring 8 MiB of memory, and is composed of chunks with shape (1, 512, 512), requiring 2 MiB each.
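These numbers follow from 8 bytes per complex64 element. The sketch below checks them with a lazy plain-Dask array of the same shape and chunking. It is a stand-in, not the abTEM array, and creating it allocates no memory.

```python
import numpy as np
import dask.array as da

# Same layout as the exit waves above: 4 frozen phonon configurations of
# 512 x 512 complex64 (8 bytes per element), one configuration per chunk.
waves = da.zeros((4, 512, 512), dtype=np.complex64, chunks=(1, 512, 512))

MiB = 2**20
total_mib = waves.nbytes / MiB                                      # 4 * 512 * 512 * 8 bytes
chunk_mib = np.prod(waves.chunksize) * waves.dtype.itemsize / MiB   # 512 * 512 * 8 bytes

print(total_mib, chunk_mib, waves.npartitions)  # 8.0 2.0 4
```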

Important

The Dask array just represents a task graph, not the data itself, so memory is consumed only if it is computed!

Each chunk of the Dask array created above represents the exit wave function for one frozen phonon configuration. This reflects the fact that, in the multislice algorithm, each frozen phonon configuration is independent and may be calculated in parallel. On the other hand, a single wave function should not be split into chunks along its two spatial dimensions, because each part of a wave function is affected by every other part during propagation.

We can visualize the task graph using Dask’s visualize method. We see that the task graph consists of 4 fully independent branches, one for each frozen phonon configuration.

Note

Drawing Dask task graphs with the cytoscape engine requires the ipycytoscape Python library. To reproduce the result below, you need to run:

  python -m pip install ipycytoscape

and restart Jupyter.

exit_waves.array.visualize(engine="cytoscape")

We usually want to take the mean across the frozen phonon dimension, and thus end up with an image represented as a single chunk.

hrtem_image = exit_waves.intensity().mean(0)

hrtem_image.array
            Array         Chunk
----------  ------------  ------------
Bytes       1.00 MiB      1.00 MiB
Shape       (512, 512)    (512, 512)
Dask graph  1 chunks in 22 graph layers
Data type   float32 numpy.ndarray

Taking the mean across the frozen phonon chunks requires communicating the exit wave intensities between the final nodes. Showing the task graph again, we see how the branches are merged where results have to be communicated between workers.

hrtem_image.array.visualize(engine="cytoscape")
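The same merging can be seen in plain Dask on a small stand-in array of ones (not the abTEM object). Computing the intensity is chunk-wise, but the mean over the first axis has to combine all four chunks, so the result ends up as a single chunk.

```python
import numpy as np
import dask.array as da

# Four independent chunks, one per frozen phonon configuration.
waves = da.ones((4, 64, 64), dtype=np.complex64, chunks=(1, 64, 64))

# The intensity is computed chunk by chunk, but the mean over axis 0
# must combine all four chunks, so the result has a single chunk.
image = (abs(waves) ** 2).mean(axis=0)

print(waves.numblocks, image.numblocks)  # (4, 1, 1) (1, 1)
print(float(image.compute().max()))      # 1.0
```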

Chunks

To further explore the role of chunks in abTEM, we create the task graph for running a STEM simulation of the same specimen: gold in the \(\left<100\right>\) zone axis with 4 frozen phonon configurations. We do not immediately apply a detector, so we obtain an ensemble of exit wave functions.

probe = abtem.Probe(energy=200e3, semiangle_cutoff=20)

scan = abtem.GridScan(
    start=(0, 0),
    end=(1 / 5, 1 / 5),
    fractional=True,
    potential=potential,
)

frozen_phonons = abtem.FrozenPhonons(
    atoms, num_configs=4, sigmas=0.1, ensemble_mean=True
)

potential = abtem.Potential(frozen_phonons, gpts=512, slice_thickness=2)

exit_waves_stem = probe.multislice(potential, scan=scan)

exit_waves_stem.axes_metadata
type               label           coordinates
-----------------  --------------  -------------------
FrozenPhononsAxis  Frozen phonons  -
ScanAxis           x [Å]           0.00 0.29 ... 3.79
ScanAxis           y [Å]           0.00 0.29 ... 3.79
RealSpaceAxis      x [Å]           0.00 0.04 ... 20.36
RealSpaceAxis      y [Å]           0.00 0.04 ... 20.36

The wave functions are now represented as a 5D Dask array: a 3D ensemble of 2D wave functions. The ensemble consists of one frozen phonon dimension and two scan dimensions, one for each of the \(x\) and \(y\) directions.

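The full array has the shape (4, 14, 14, 512, 512), requiring 1.53 GiB of memory, and is cut into chunks of shape (1, 8, 7, 512, 512) of at most 112 MiB. The 14 scan positions along \(x\) are split into batches of 8 and 6, and those along \(y\) into batches of 7 and 7, so there are a total of \(4 \times 2 \times 2 = 16\) chunks.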

exit_waves_stem.array
            Array                  Chunk
----------  ---------------------  -------------------
Bytes       1.53 GiB               112.00 MiB
Shape       (4, 14, 14, 512, 512)  (1, 8, 7, 512, 512)
Dask graph  16 chunks in 22 graph layers
Data type   complex64 numpy.ndarray
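As a rough check of this layout, the sketch below builds a lazy plain-Dask array with the same shape and chunk shape and reads off the chunk boundaries, the chunk count and the sizes. No abTEM is involved and nothing is allocated. Note how the last chunk along the first scan axis is smaller.

```python
import numpy as np
import dask.array as da

# Lazy stand-in with the same shape and chunking as the STEM exit waves:
# (frozen phonons, scan x, scan y, wave x, wave y). Nothing is allocated.
waves = da.zeros(
    (4, 14, 14, 512, 512), dtype=np.complex64, chunks=(1, 8, 7, 512, 512)
)

print(waves.chunks[1], waves.chunks[2])  # (8, 6) (7, 7): last x-chunk is smaller
print(waves.numblocks)                   # (4, 2, 2, 1, 1) -> 16 chunks
print(waves.nbytes / 2**30)              # total size in GiB, about 1.53
print(8 * 7 * 512 * 512 * 8 / 2**20)     # largest chunk in MiB: 112.0
```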

Notice that we do not make a chunk for every probe position. Instead, each chunk along the scan dimensions represents a batch of wave functions. This is done partly to limit the overhead that comes with every chunk, but more importantly because larger batches enable efficient thread parallelization within each run of the multislice algorithm.

We can change how many wave functions each batch contains using the max_batch keyword. Below we set max_batch=4, which results in batches of \(2 \times 2\) probe positions and a total of \(4 \times 7 \times 7 = 196\) chunks.

exit_waves_stem = probe.multislice(potential, scan=scan, max_batch=4)

exit_waves_stem.array
            Array                  Chunk
----------  ---------------------  -------------------
Bytes       1.53 GiB               8.00 MiB
Shape       (4, 14, 14, 512, 512)  (1, 2, 2, 512, 512)
Dask graph  196 chunks in 21 graph layers
Data type   complex64 numpy.ndarray
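The chunk count can be read off the chunk shape (1, 2, 2, 512, 512) in the output. The helper num_chunks below is hypothetical and not part of abTEM. It only does the bookkeeping: one chunk per frozen phonon configuration, times the number of scan batches along each axis, rounding up when the batch size does not divide the number of positions.

```python
import math

def num_chunks(num_configs, scan_shape, batch_shape):
    """Hypothetical helper: count chunks for one chunk per frozen phonon
    configuration and scan batches of shape batch_shape."""
    scan_chunks = [math.ceil(n / b) for n, b in zip(scan_shape, batch_shape)]
    return num_configs * math.prod(scan_chunks)

# max_batch=4 gave 2 x 2 probe positions per chunk on a 14 x 14 scan.
print(num_chunks(4, (14, 14), (2, 2)))  # 4 * 7 * 7 = 196
# The default batches of 8 x 7 positions.
print(num_chunks(4, (14, 14), (8, 7)))  # 4 * 2 * 2 = 16
```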

The default value of max_batch is "auto": with this setting the number of wave functions in each batch is determined such that the batch represents approximately 128 MB of memory, although this number may be changed through the abTEM configuration.
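As a rough illustration of the memory budget behind "auto", the sketch below counts how many 512 × 512 complex64 wave functions fit in 128 MB. It assumes "128 MB" is read as decimal megabytes, as dask.utils.parse_bytes does; whether abTEM parses it the same way is an assumption. The sketch also does not reproduce how abTEM distributes the budget over the two scan axes. The default chunk above holds 8 × 7 = 56 wave functions (112 MiB), which stays within this budget.

```python
from dask.utils import parse_bytes

# One 512 x 512 complex64 wave function: 512 * 512 * 8 bytes = 2 MiB.
bytes_per_wave = 512 * 512 * 8

# Assumed interpretation of a "128 MB" budget (decimal megabytes).
budget = parse_bytes("128 MB")        # 128_000_000 bytes
max_waves = budget // bytes_per_wave  # whole wave functions that fit

print(max_waves)  # 61
```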

Before computing this task graph, we apply a HAADF detector and calculate the ensemble mean, which reduces the total size of the output to just 784 B (\(14 \times 14\) float32 values). Note that the 1.53 GiB ensemble of wave functions never needs to be in memory simultaneously, as each chunk of exit wave functions is reduced immediately after completing the multislice algorithm.

detector = abtem.AnnularDetector(inner=65, outer=200)

haadf_images = detector.detect(exit_waves_stem).reduce_ensemble()

haadf_images.array
            Array     Chunk
----------  --------  ------
Bytes       784 B     16 B
Shape       (14, 14)  (2, 2)
Dask graph  49 chunks in 26 graph layers
Data type   float32 numpy.ndarray
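The reduction pattern can be imitated in plain Dask. In the sketch below, a small ensemble of ones stands in for the exit waves. Summing the intensity over all pixels of each wave stands in for the detector; this is not what AnnularDetector computes. Because every chunk is reduced on its own before the ensemble mean, the scheduler can discard each chunk of waves once it has been reduced.

```python
import numpy as np
import dask.array as da

# Small stand-in ensemble: 4 configurations, 14 x 14 probe positions,
# 32 x 32 pixels per wave, batched 2 x 2 positions per chunk.
waves = da.ones((4, 14, 14, 32, 32), dtype=np.complex64, chunks=(1, 2, 2, 32, 32))

# Stand-in "detector": integrate the intensity over all pixels of each wave.
# Each chunk is reduced independently, right after it is produced.
signal = (abs(waves) ** 2).sum(axis=(-2, -1))

# Average over the frozen phonon configurations.
image = signal.mean(axis=0)

print(image.shape, image.numblocks)  # (14, 14) (7, 7)
result = image.compute()
print(float(result[0, 0]))           # 1024.0 = 32 * 32 pixels of intensity 1
```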

Schedulers

After generating a task graph, it needs to be executed on (parallel) hardware. This is the job of a task scheduler. Dask provides several task schedulers, each of which computes a task graph and gives the same result, but with different performance characteristics.

Every time you call the compute method, a Dask scheduler is used. abTEM adopts the default Dask scheduler configuration, and every keyword argument passed to the compute method in abTEM is forwarded to the Dask compute function.
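For reference, this is what those keyword arguments look like when calling dask.compute directly in plain Dask. The array is an arbitrary stand-in. The threaded and synchronous schedulers give the same result and differ only in how the tasks are executed.

```python
import dask
import dask.array as da

x = da.arange(1_000_000, chunks=100_000)
total = x.sum()

# Keyword arguments to dask.compute select and configure the scheduler.
(threaded,) = dask.compute(total, scheduler="threads", num_workers=2)
(serial,) = dask.compute(total, scheduler="synchronous")

print(threaded == serial)  # True: same result, different execution strategy
```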

Local scheduler

The default scheduler is the threaded scheduler, which runs tasks on a ThreadPoolExecutor. As an example of a keyword argument, the threaded scheduler takes num_workers, which sets the number of threads to use (by default, up to the number of computing cores). Let’s run compute with 8 workers.

haadf_images = detector.detect(exit_waves_stem).reduce_ensemble()

haadf_images.compute(scheduler="threads", num_workers=8)
[########################################] | 100% Completed | 19.70 s
<abtem.measurements.Images at 0x2a9a0741dd0>

We can switch to the multiprocessing scheduler, which runs tasks on a ProcessPoolExecutor, as below.

haadf_images = detector.detect(exit_waves_stem).compute(scheduler="processes", num_workers=4)

Using abtem.config.set, the scheduler can be set either within a context manager or globally.

# As a context manager
with abtem.config.set(scheduler="processes"):
    haadf_images.compute()

# Set globally
abtem.config.set(scheduler="processes")
haadf_images.compute()
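For comparison, plain Dask offers the same two patterns through dask.config.set. The snippet only shows Dask's own configuration. How abtem.config.set relates to the Dask configuration internally is not assumed here.

```python
import dask

# dask.config.set works both as a context manager and as a global setting.
with dask.config.set(scheduler="synchronous"):
    inside = dask.config.get("scheduler")

# After the block, the previous configuration is restored.
outside = dask.config.get("scheduler", None)

print(inside, outside)
```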

Profiling

To improve performance, we have to be able to measure it. Profiling parallel code can be challenging, but Dask provides tools for profiling and inspecting execution. The diagnostic tools are, however, quite different depending on whether you use a local or a distributed scheduler.

Dask allows local diagnostics by adding callbacks that collect information about your code execution. You can use the profilers as a context manager as described in the Dask documentation. For convenience, the abTEM compute methods also implement keywords for adding Dask profilers.
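A minimal sketch of the context-manager form in plain Dask, using an arbitrary random array as the workload. Profiler records one entry per task, with start and end times, when a local scheduler runs inside the with block. ResourceProfiler is used the same way but additionally requires the psutil package. Calling prof.visualize() would draw the Bokeh plot discussed below.

```python
import dask.array as da
from dask.diagnostics import Profiler

x = da.random.random((2000, 2000), chunks=(500, 500))

# Profiler records the start and end time of every task on a local scheduler.
with Profiler() as prof:
    y = (x + x.T).mean().compute(scheduler="threads")

print(len(prof.results))  # number of recorded tasks
first = prof.results[0]
print(first.key, first.end_time - first.start_time)
```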

Below we use the Profiler to monitor task execution by setting profiler=True and a ResourceProfiler to monitor the CPU usage and memory consumption by setting resource_profiler=True. We rerun the simulation above with these profilers.

haadf_images = detector.detect(exit_waves_stem).reduce_ensemble()

haadf_images, profilers = haadf_images.compute(profiler=True, resource_profiler=True)
[########################################] | 100% Completed | 19.44 s

Dask uses the Bokeh plotting library to display the results (Bokeh is an optional Dask dependency and may need to be installed separately). To display the plots in a Jupyter notebook, output_notebook from bokeh.io has to be called first, as is done at the top of this notebook.

We first show the result from the Profiler object. This shows the execution time for each task as a rectangle, organized along the vertical axis by worker (in this case threads), where white space represents time during which a worker was idle. The task types are grouped by color, and by hovering over each task, one can see the key and task that each block represents. For this calculation there is only one significant task, shown in yellow. It encompasses building the wave function, running the multislice algorithm (calculating the potential on demand), and calculating and integrating the diffraction patterns.

profilers[0].visualize();

The result from the ResourceProfiler is shown below: This shows two lines, one for total CPU percentage used by all the workers, and one for total memory usage. The CPU usage is scaled so each worker contributes up to \(100 \ \%\), i.e. two fully utilized workers use \(200 \ \%\).

profilers[1].visualize();

We ran the calculation on a single computer with a 4-core CPU with 2 threads per core, for 8 threads in total. We see that our peak CPU usage was \(\sim 600\ \%\). This is a fairly typical usage statistic for a single machine: overhead and background processes prevented us from reaching the \(800\ \%\) corresponding to using every available thread maximally.

We also see that the total memory use reached around 800 MB. If you are running the calculation on a system with more threads your memory consumption may be larger – it may even exceed the total memory cost of all the wave functions, as every parallel run of the multislice algorithm requires a significant overhead for intermediate results (such as potential slices and Fresnel propagators). If your calculation runs out of memory you can lower the number of workers, thus trading away computational speed for lower memory consumption.

We note that the relative overhead in both CPU usage and memory diminishes for larger simulations on more powerful hardware.

The (locally) distributed scheduler

The Dask distributed scheduler is necessary for running your simulations on a cluster, but it also runs locally on a personal machine. You can find details in the Dask documentation, but we demonstrate the basics below.

You can use the Dask distributed scheduler simply by initializing a Dask Client. The Client takes keyword arguments such as n_workers (note that this is different from num_workers above!).

from dask.distributed import Client

client = Client(n_workers=6)
client

Client

Client-e5a450c3-03be-11ef-bee0-3800259eb785

Connection method: Cluster object Cluster type: distributed.LocalCluster
Dashboard: http://127.0.0.1:58706/status


After initializing the client object, any abTEM computation will use the Dask distributed scheduler.

haadf_images = detector.detect(exit_waves_stem).reduce_ensemble()
haadf_images.compute()
<abtem.measurements.Images at 0x2a9bb169010>

A benefit of using the distributed scheduler on a single machine is the live diagnostic dashboard. You can access it through the Dashboard link shown in the __repr__ of the Client above (for details, you can watch this video walkthrough). If you are using JupyterLab, the Dask JupyterLab extension (dask-labextension) provides the same information as panels inside the JupyterLab editor.

We can get back to the local scheduler by closing the Client.

client.close()
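A Client can also be used as a context manager. It then closes itself, and the local cluster it started, when the block ends, which has the same effect as calling client.close(). The sketch uses a plain Dask array as a stand-in workload. processes=False keeps the workers as threads in the current process, and dashboard_address=None disables the dashboard web server. Both are settings chosen only to keep the example lightweight.

```python
import dask.array as da
from dask.distributed import Client

x = da.arange(1_000_000, chunks=100_000)

# The context manager starts a LocalCluster and shuts it down on exit.
with Client(processes=False, n_workers=2, threads_per_worker=1,
            dashboard_address=None) as client:
    total = x.sum().compute()  # runs on the distributed scheduler

# Outside the block, computations use the local scheduler again.
print(total)  # 499999500000
```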

Running abTEM on HPC clusters

Dask (and thus abTEM) has robust tools for deployment on high-performance compute clusters. We recommend consulting your HPC provider on how to deploy Dask applications on your available cluster. For general advice on deployment, please see the Dask documentation.

Submitting job scripts

As an overview, Dask provides a number of different cluster managers, so you can use distributed Dask on a range of platforms. These cluster managers deploy a scheduler and the necessary workers as determined by communicating with the resource manager. All cluster managers follow the same interface but have platform-specific configuration options.

As an example, for deployment using SLURM, your script might look something like:

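from dask_jobqueue import SLURMCluster
from dask.distributed import Client

# Resources of a single SLURM job (one batch of workers)
cluster = SLURMCluster(
    queue="regular",
    account="myaccount",
    cores=32,
    memory="128 GB"
)

# The cluster starts without workers until it is scaled
cluster.scale(jobs=2)

client = Client(cluster)

# Your abTEM code goes here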

Dask also supports deployment from within an existing MPI environment, such as one created with the common MPI command-line launcher mpirun (see the Dask-MPI documentation for more information). You can turn your batch Python script into an MPI executable with the dask_mpi.initialize function.

from dask_mpi import initialize
initialize()

from dask.distributed import Client
client = Client()  # Connect this local process to remote workers

# Your abTEM code goes here

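This makes your Python script launchable directly with mpirun. With dask_mpi.initialize, MPI rank 0 runs the Dask scheduler, rank 1 continues executing your client code, and all remaining ranks become workers, so the command below starts two workers: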

mpirun -np 4 python my_client_script.py

Using GPUs

Almost every part of abTEM can be accelerated using a GPU through the CuPy library. We have only tested CUDA-compatible GPUs, but any graphics card compatible with CuPy should work.

If you have a compatible GPU and a working installation of CuPy, you can accelerate your image simulations simply by changing the device in the abTEM config at the top of your script or notebook, as below:

abtem.config.set({"device": "gpu"});

On a single GPU, you may also want to limit the number of workers to one, num_workers=1. This may be done through the Dask config:

dask.config.set({"num_workers": 1});

Note that Dask does not manage GPU threads. This makes the choice of batch sizes (i.e. propagating multiple wave functions in a single batch) extremely important in order to fully utilize your GPU. The default batch size in GPU calculations in abTEM is \(512\ \mathrm{MB}\), which is four times larger than the CPU batch size, but if your GPU has \(8\ \mathrm{GB}\) or more memory, you will likely be able to increase this number to \(2048\ \mathrm{MB}\).

abtem.config.set({"dask.chunk-size-gpu": "2048 MB"})  # e.g. on a GPU with 8 GB or more memory

Note

The batch size only limits the memory taken up by the wave functions in a batch, so you need to leave room in GPU memory for any intermediate overhead.

Finally, abTEM by default sets the size of CuPy’s FFT plan cache to zero, as we find that in most cases the increased memory consumption of the plans is not worth the small speedup they provide. You can, however, change this through the abTEM config.

abtem.config.set({"cupy.fft-cache-size" : "1024 MB"})

Multiple GPUs

The above is enough for running abTEM on a single GPU, although with an NVIDIA GPU you may also want to install dask_cuda (currently only supported on Linux). For multi-GPU calculations with abTEM, however, dask_cuda is currently necessary.

from dask_cuda import LocalCUDACluster
from dask.distributed import Client

# Start one Dask worker per visible GPU
cluster = LocalCUDACluster()
client = Client(cluster)