Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
432 changes: 432 additions & 0 deletions benchmarks/gpu_benchmark.py

Large diffs are not rendered by default.

4 changes: 1 addition & 3 deletions src/alphafold3/common/folding_input.py
Original file line number Diff line number Diff line change
Expand Up @@ -397,9 +397,7 @@ def to_dict(
{
'mmcif': template.mmcif,
'queryIndices': list(template.query_to_template_map.keys()),
'templateIndices': (
list(template.query_to_template_map.values()) or None
),
'templateIndices': list(template.query_to_template_map.values()),
}
for template in self._templates
]
Expand Down
2 changes: 1 addition & 1 deletion src/alphafold3/data/msa.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@


class Error(Exception):
"""Error indicatating a problem with MSA Search."""
"""Error indicating a problem with MSA Search."""


def _featurize(seq: str, chain_poly_type: str) -> str | list[int]:
Expand Down
2 changes: 1 addition & 1 deletion src/alphafold3/data/tools/nhmmer.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def __init__(

NOTE: The MSA obtained by running against sharded dbs won't be always
exactly the same as the MSA obtained by running against an unsharded db.
This is because of Jackhmmer deduplication logic, which won't spot duplicate
This is because of Nhmmer deduplication logic, which won't spot duplicate
hits across multiple shards. Usually this means that the sharded search
finds more hits (likely bounded by the number of shards), but this should
not pose an issue given how the results are used downstream. The problem is
Expand Down
25 changes: 25 additions & 0 deletions src/alphafold3/model/gpu/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
# Copyright 2024 DeepMind Technologies Limited
#
# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of
# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/
#
# To request access to the AlphaFold 3 model parameters, follow the process set
# out at https://github.com/google-deepmind/alphafold3. You may only use these
# if received directly from Google. Use is subject to terms of use available at
# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md

"""GPU acceleration utilities for AlphaFold 3.

Modules
-------
fused_ops
Memory-efficient fused operations (OuterProductMean three-way einsum fusion).
parallel
Multi-GPU inference via jax.pmap for diffusion sample parallelism.
xla_cache
Persistent XLA compilation cache setup and management.
"""

from alphafold3.model.gpu import fused_ops
from alphafold3.model.gpu import parallel
from alphafold3.model.gpu import xla_cache
91 changes: 91 additions & 0 deletions src/alphafold3/model/gpu/fused_ops.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
# Copyright 2024 DeepMind Technologies Limited
#
# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of
# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/
#
# To request access to the AlphaFold 3 model parameters, follow the process set
# out at https://github.com/google-deepmind/alphafold3. You may only use these
# if received directly from Google. Use is subject to terms of use available at
# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md

"""Memory-efficient fused GPU operations.

OuterProductMean — three-way einsum fusion
------------------------------------------
The standard OuterProductMean compute_chunk creates a large intermediate tensor
of shape [full_N, C_outer, C_outer, chunk] by running two separate einsums:

step1 = einsum('acb,ade->dceb', left_T, right) # [N, C, C, chunk] ~268 MB
step2 = einsum('dceb,cef->dbf', step1, W) + b # [N, chunk, F_out]

At N=1024, C_outer=32, chunk=128 in bfloat16 the intermediate is ~268 MB.

This module replaces both steps with a single three-way einsum:

result = einsum('acb,ade,cef->dbf', left_T, right, W) + b

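With all three operands in one call, ``jnp.einsum`` plans the contraction order
itself (see its ``optimize`` argument) instead of being forced through the
[N, C, C, chunk] tensor. The intended order contracts the two C_outer
dimensions first, producing an [msa, chunk, N, F_out] intermediate before the
final MSA sum; the ~32 MB figure quoted for that intermediate depends on the
MSA depth and F_out, which are not fixed here, so treat it as indicative. This
avoids the C^2 blowup while keeping the operation a single einsum expression
with no Python-level loop overhead.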

Mathematical equivalence
------------------------
Both formulations compute:
result[m, n, f] = Σ_{a, c_l, r} left_T[a, c_l, m] · right[a, n, r] · W[c_l, r, f]
where a=MSA, c_l=left channel, r=right channel, m=chunk, n=full_N, f=output.
"""

import jax.numpy as jnp


def fused_outer_product_chunk(
    left_act: jnp.ndarray,
    right_act: jnp.ndarray,
    output_w: jnp.ndarray,
    output_b: jnp.ndarray,
) -> jnp.ndarray:
  """Compute one OuterProductMean chunk with a single three-operand einsum.

  The value returned is

    out[b, d, f] = sum_{a, c, e} left_act[a, b, c] * right_act[a, d, e]
                                 * output_w[c, e, f]  +  output_b[f]

  i.e. the same contraction as the two-einsum formulation, but handed to
  ``jnp.einsum`` as one expression so that the
  [full_N, C_outer, C_outer, chunk] outer-product tensor is never requested
  explicitly. Which intermediates are actually materialised is decided by the
  einsum contraction planner and the compiler, not by this function.

  The MSA axis is summed; no division by the MSA depth happens here.

  Args:
    left_act: [num_msa, chunk, C_outer] left projections for one residue
      chunk.
    right_act: [num_msa, full_N, C_outer] right projections for all residues.
    output_w: [C_outer, C_outer, F_out] output projection weights.
    output_b: [F_out] output projection bias.

  Returns:
    Array of shape [chunk, full_N, F_out].
  """
  # Einsum index letters:
  #   a = MSA row (summed)          b = residue within the chunk
  #   c = left channel (summed)     d = residue in the full sequence
  #   e = right channel (summed)    f = output feature
  subscripts = 'acb,ade,cef->dbf'

  # The subscripts expect the left operand channel-major:
  # [msa, chunk, C_l] -> [msa, C_l, chunk].
  left_channels_first = jnp.transpose(left_act, axes=(0, 2, 1))

  residue_major = jnp.einsum(
      subscripts, left_channels_first, right_act, output_w
  )  # [full_N, chunk, F_out]

  # Callers expect the chunk axis first; the bias broadcasts over F_out.
  chunk_major = jnp.transpose(residue_major, axes=(1, 0, 2))
  return chunk_major + output_b
190 changes: 190 additions & 0 deletions src/alphafold3/model/gpu/parallel.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,190 @@
# Copyright 2024 DeepMind Technologies Limited
#
# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of
# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/
#
# To request access to the AlphaFold 3 model parameters, follow the process set
# out at https://github.com/google-deepmind/alphafold3. You may only use these
# if received directly from Google. Use is subject to terms of use available at
# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md

"""Multi-GPU inference utilities for AlphaFold 3.

Status: EXPERIMENTAL / NOT YET INTEGRATED
------------------------------------------
The functions in this module provide building blocks for distributing diffusion
samples across multiple GPUs via ``jax.pmap``. They are **not yet wired into
the AlphaFold 3 inference pipeline** — ``model.py``'s ``_sample_diffusion``
method still runs all samples on a single device.

Before using ``make_parallel_diffusion_fn`` in production, note that
``diffusion_head.sample()`` internally uses ``hk.vmap``; stacking ``jax.pmap``
on top requires ``hk.lift()`` to correctly track Haiku module state under
nested JAX transforms (see the Haiku transforms documentation for details).

Overview
--------
AlphaFold 3 generates ``num_samples`` independent structure predictions via its
diffusion head. By default all samples run on a single device (vmapped over the
sample axis). On machines with N GPUs the intent is to split those samples
across devices using ``jax.pmap``, yielding a near-linear speed-up for the
diffusion phase.

Stable / tested utilities in this module
-----------------------------------------
* ``round_samples_to_devices`` — pad sample count for even device splitting.
* ``get_num_devices`` / ``log_device_info`` — device introspection helpers.
* ``_split_along_leading_axis`` / ``_concat_along_leading_axis`` — array
reshape helpers (used and tested independently of Haiku).
"""

import math
from collections.abc import Callable
from typing import Any

from absl import logging
import jax
import jax.numpy as jnp
import numpy as np


# Annotation alias for JAX pytree arguments and return values in this module.
# It is plain ``Any``, so it documents intent without adding static checking.
PyTree = Any


def round_samples_to_devices(num_samples: int, num_devices: int) -> int:
  """Round ``num_samples`` up so it is evenly divisible by ``num_devices``.

  The computation uses integer arithmetic only, so the result is exact for
  arbitrarily large sample counts (a float division followed by ``math.ceil``
  can round below ``num_samples`` once the values exceed float precision).

  Args:
    num_samples: Desired number of diffusion samples.
    num_devices: Number of accelerator devices available. Must be >= 1.

  Returns:
    The smallest integer >= ``num_samples`` that is divisible by
    ``num_devices``.

  Raises:
    ValueError: If ``num_devices`` is not a positive integer.
  """
  if num_devices < 1:
    raise ValueError(
        f'num_devices must be a positive integer, got {num_devices}.'
    )
  # Ceiling division without floats: -(-a // b) == ceil(a / b) for b > 0.
  samples_per_device = -(-num_samples // num_devices)
  return samples_per_device * num_devices


def _split_along_leading_axis(
    tree: PyTree,
    num_devices: int,
) -> PyTree:
  """Shard every leaf's leading axis into a (device, per-device) pair.

  Each leaf of shape [N, ...] is reshaped to
  [num_devices, N // num_devices, ...]; the pytree structure is unchanged.

  Args:
    tree: A JAX pytree whose leaves expose ``shape`` and ``reshape``. Every
      leaf's leading dimension must be divisible by ``num_devices``.
    num_devices: Number of devices to shard across.

  Returns:
    A pytree with the same structure whose leaves have shape
    [num_devices, N // num_devices, ...].

  Raises:
    ValueError: If any leaf's leading dimension is not divisible by
      ``num_devices``.
  """

  def _shard_leaf(leaf: np.ndarray) -> np.ndarray:
    leading = leaf.shape[0]
    per_device, leftover = divmod(leading, num_devices)
    if leftover:
      raise ValueError(
          f'Leading dimension {leading} is not divisible by num_devices'
          f' {num_devices}.'
      )
    # Row-major reshape: device i receives the i-th contiguous block of rows.
    return leaf.reshape((num_devices, per_device, *leaf.shape[1:]))

  return jax.tree_util.tree_map(_shard_leaf, tree)


def _concat_along_leading_axis(tree: PyTree) -> PyTree:
  """Merge [num_devices, shard_size, ...] back to [N, ...].

  The merged size is computed explicitly as ``num_devices * shard_size``
  rather than requested with ``-1``: ``reshape`` cannot infer a ``-1``
  dimension when one of the remaining dimensions has size 0, so leaves such as
  [num_devices, shard_size, 0] would otherwise raise.

  Args:
    tree: A JAX pytree whose leaves have shape
      [num_devices, shard_size, ...].

  Returns:
    A pytree with leaves of shape [N, ...] where
    N = num_devices * shard_size. Leaves with fewer than two axes have no
    shard axis to merge and are returned flattened to one dimension.
  """

  def _merge_leaf(leaf: np.ndarray) -> np.ndarray:
    shape = leaf.shape
    if len(shape) < 2:
      # Nothing to merge (e.g. one scalar per device); flatten only.
      return leaf.reshape((-1,))
    return leaf.reshape((shape[0] * shape[1],) + tuple(shape[2:]))

  return jax.tree_util.tree_map(_merge_leaf, tree)


def make_parallel_diffusion_fn(
    single_device_fn: Callable[..., PyTree],
) -> Callable[..., PyTree]:
  """Wrap a single-device diffusion sampling function for multi-GPU execution.

  The wrapped function:

  1. Expects ``positions`` with a leading ``num_samples`` axis.
  2. Reshapes that axis into ``[num_devices, num_samples // num_devices, ...]``.
  3. Broadcasts all other positional and keyword arguments (keys, masks,
     embeddings) unchanged to every device. Pass them *without* a leading
     device axis.
  4. Calls ``jax.pmap`` of ``single_device_fn`` across the local devices.
  5. Merges the per-device outputs back into a single leading ``num_samples``
     axis.

  ``jax.pmap`` maps every argument over axis 0 by default and always does so
  for keyword arguments. To actually broadcast the non-sample arguments they
  are packed into one tuple and one dict and passed positionally with
  ``in_axes=None``.

  Args:
    single_device_fn: A function that runs diffusion sampling for a sub-batch
      of samples on a single device. Its first argument must be an array with
      a leading sample axis.

  Returns:
    A function ``parallel_fn(positions, *args, **kwargs)`` that runs
    ``single_device_fn`` under ``jax.pmap``. It raises ``ValueError`` when the
    number of samples is not divisible by the number of local devices.
  """

  def _per_device(
      positions_shard: np.ndarray,
      shared_args: tuple[Any, ...],
      shared_kwargs: dict[str, Any],
  ) -> PyTree:
    # Unpack the broadcast arguments back into the caller's original layout.
    return single_device_fn(positions_shard, *shared_args, **shared_kwargs)

  # Axis 0 of ``positions`` is sharded across devices; ``None`` broadcasts the
  # whole args tuple / kwargs dict to every device.
  pmapped = jax.pmap(
      _per_device, axis_name='devices', in_axes=(0, None, None)
  )

  def parallel_fn(positions: np.ndarray, *args: Any, **kwargs: Any) -> PyTree:
    # pmap's mapped axis is bounded by the devices attached to this process,
    # not by the global device count.
    num_devices = jax.local_device_count()
    num_samples = positions.shape[0]

    if num_samples % num_devices != 0:
      raise ValueError(
          f'num_samples ({num_samples}) must be divisible by the number of'
          f' available devices ({num_devices}). Use'
          f' round_samples_to_devices() to pad.'
      )

    # Split sample axis across devices: [N, ...] -> [D, N/D, ...]
    positions_split = _split_along_leading_axis(positions, num_devices)

    logging.info(
        'Multi-GPU diffusion: %d samples across %d devices (%d per device)',
        num_samples,
        num_devices,
        num_samples // num_devices,
    )

    result_split = pmapped(positions_split, args, kwargs)

    # Gather results: [D, N/D, ...] -> [N, ...]
    return _concat_along_leading_axis(result_split)

  return parallel_fn


def get_num_devices() -> int:
  """Count the accelerator devices visible to JAX.

  Only devices whose ``platform`` is ``'gpu'`` or ``'tpu'`` are counted. An
  info message is logged when more than one is found.

  Returns:
    The number of GPU/TPU devices, or 1 when none are present (for example on
    a CPU-only backend).
  """
  accelerator_platforms = ('gpu', 'tpu')
  accelerator_count = sum(
      1 for device in jax.devices() if device.platform in accelerator_platforms
  )
  if accelerator_count == 0:
    # No accelerator: report a single (CPU) device.
    return 1
  if accelerator_count > 1:
    logging.info('Multi-GPU mode: %d devices available.', accelerator_count)
  return accelerator_count


def log_device_info() -> None:
  """Log the JAX device count and each device's platform and kind.

  One info line reports the total number of devices returned by
  ``jax.devices()``; one further line per device reports its index,
  ``platform`` and ``device_kind``.
  """
  visible_devices = jax.devices()
  logging.info('JAX device count: %d', len(visible_devices))
  for index, device in enumerate(visible_devices):
    platform, kind = device.platform, device.device_kind
    logging.info(
        ' Device %d: platform=%s, device_kind=%s', index, platform, kind
    )
Loading