Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 17 additions & 1 deletion tests/tsfc/test_underintegration.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import pytest

from ufl import (Mesh, FunctionSpace, TestFunction, TrialFunction, TensorProductCell, dx,
action, interval, quadrilateral, dot, grad)
action, interval, quadrilateral, hexahedron, dot, grad)
from finat.ufl import FiniteElement, VectorElement

from FIAT import ufc_cell
Expand Down Expand Up @@ -100,6 +100,22 @@ def test_laplace(cell, order):
assert (rates < order).all()


@pytest.mark.parametrize(('cell', 'order'),
[(quadrilateral, 3),
(hexahedron, 4)])
def test_laplace_action(cell, order):
# The matrix-free action of the collocated GLL Laplacian must sum factorise:
# flops should grow no faster than O(p^{d+1}), i.e. rate < d + 1. On a
# genuine d-way tensor-product cell (e.g. hexahedron) the value tabulation
# is the identity but was previously materialised as a dense tabulation
# broadcast over the other quadrature directions, which defeated sum
# factorisation and made the action scale like O(p^{2d}) (rate ~5.8 in 3D).
degrees = numpy.arange(4, 10)
flops = [count_flops(action(laplace(cell, int(degree)))) for degree in degrees]
rates = numpy.diff(numpy.log(flops)) / numpy.diff(numpy.log(degrees + 1))
assert (rates < order).all()


if __name__ == "__main__":
import os
import sys
Expand Down
202 changes: 198 additions & 4 deletions tsfc/spectral.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
from collections import OrderedDict, defaultdict, namedtuple
from functools import partial
from functools import partial, singledispatch
from itertools import chain, zip_longest

from gem.gem import Delta, Indexed, Sum, index_sum, one
from gem.node import Memoizer, MemoizerArg
import numpy

from gem.gem import (ComponentTensor, Delta, Index, Indexed, IndexSum,
Literal, Node, Sum, index_sum, one)
from gem.node import Memoizer, MemoizerArg, reuse_if_untouched, traversal
from gem.optimise import filtered_replace_indices
from gem.optimise import delta_elimination as _delta_elimination
from gem.optimise import replace_division, unroll_indexsum
Expand All @@ -18,6 +21,183 @@
'argument_indices'])


# -- Exposing collocation structure to sum factorisation ---------------------
#
# On tensor-product cells with a collocated quadrature rule (e.g. a
# ``variant="spectral"`` element integrated at its own Gauss-Lobatto-Legendre
# nodes), the value tabulation is the identity. However FInAT/TSFC materialise
# the tensor-product tabulation as a dense multi-dimensional ``Literal`` that
# factors as ``T[i, q_own, q_a, q_b] = factor[i, q_own] * const`` -- i.e. a
# genuine 1D factor spuriously *broadcast* (constant) over the other quadrature
# directions. That broadcast hides both the low-rank structure and the
# collocation identity from sum factorisation/delta elimination, so the
# generated operator application scales like O(p^{2d}) instead of O(p^{d+1})
# (the matvec/action of the 3D high order Laplacian was ~5x slower than it
# should be as a result).
#
# The two passes below recover the structure with purely local, exact GEM
# rewrites applied before sum factorisation:
# * ``drop_constant_literal_axes`` removes running indices on axes along which
# a tabulation literal is constant (the broadcast), uncovering the 1D
# factors;
# * ``convert_identity_literals`` rewrites a resulting identity tabulation as a
# Kronecker ``Delta`` so the delta elimination cancels the redundant
# interpolation contraction.


def _is_identity_table(array, epsilon):
"""True if ``array`` is (numerically) a square identity matrix."""
if array.ndim != 2 or array.shape[0] != array.shape[1] or array.shape[0] == 0:
return False
return numpy.allclose(array, numpy.eye(array.shape[0], dtype=array.dtype),
rtol=0.0, atol=epsilon)


def _constant_axes(array, epsilon):
"""Axes of ``array`` (length > 1) along which it is constant."""
if array.ndim < 2:
return ()
eps = epsilon * (1.0 + (numpy.abs(array).max() if array.size else 0.0))
axes = []
for axis in range(array.ndim):
if array.shape[axis] <= 1:
continue
spread = numpy.ptp(array, axis=axis)
if (spread.max() if spread.size else 0.0) <= eps:
axes.append(axis)
return tuple(axes)


def _anchored_indices(expressions, epsilon):
"""Indices that are safe to drop from a constant literal axis.

Dropping a running index from a tabulation literal removes that index from
the expression. This is only sound if the index is *anchored*: it must also
occur somewhere other than a constant axis of an ``Indexed(Literal(...))``,
so that it remains present afterwards (otherwise an enclosing
``ComponentTensor``/``IndexSum`` that binds or sums it would be left
referencing a vanished index).

An index is anchored iff it is *introduced* by some node other than via a
constant literal axis. The indices a node introduces are exactly its free
indices minus those of its children; for ``Indexed(Literal(...))`` the
constant axes are excluded.

Indices bound by a ``ComponentTensor``/``IndexSum`` anywhere in the DAG are
never anchored. The anchoring analysis is global, but binding is scoped and
GEM is a shared DAG, so an index can occur non-constantly under one binder
yet appear *only* on a constant literal axis within the scope of another
binder of the same ``Index`` object. Dropping it there would orphan that
binder's multiindex; refusing to drop any bound index avoids this while
still exposing the (free) broadcast quadrature directions we target.
"""
anchored = set()
bound = set()
for node in traversal(expressions):
if isinstance(node, (ComponentTensor, IndexSum)):
bound.update(node.multiindex)
child_free = set()
for child in node.children:
child_free |= set(child.free_indices)
own = set(node.free_indices) - child_free
if isinstance(node, Indexed) and isinstance(node.children[0], Literal):
const = set(_constant_axes(node.children[0].array, epsilon))
own = {index for axis, index in enumerate(node.multiindex)
if isinstance(index, Index) and axis not in const}
anchored |= own
return anchored - bound


@singledispatch
def _drop_constant_axes(node, self):
raise AssertionError("cannot handle type %s" % type(node))


_drop_constant_axes.register(Node)(reuse_if_untouched)


@_drop_constant_axes.register(Indexed)
def _drop_constant_axes_indexed(node, self):
child, = node.children
# If the literal genuinely does not vary along an axis indexed by a running
# index, indexing it there is redundant: drop the axis (its value is any
# slice). Only drop *anchored* indices (those that also occur elsewhere), so
# the dropped index remains present in the expression and any enclosing
# ComponentTensor/IndexSum that references it stays well formed.
if (isinstance(child, Literal) and len(node.multiindex) == child.array.ndim
and child.array.ndim >= 2):
const = set(_constant_axes(child.array, self.epsilon))
slicer = []
new_multiindex = []
dropped = False
for axis, index in enumerate(node.multiindex):
if (axis in const and isinstance(index, Index)
and index in self.anchored):
slicer.append(0) # constant along this axis: keep slice 0
dropped = True
else:
slicer.append(slice(None))
new_multiindex.append(index)
if dropped:
reduced = child.array[tuple(slicer)]
return Indexed(Literal(reduced, dtype=child.dtype),
tuple(new_multiindex))
return reuse_if_untouched(node, self)


def drop_constant_literal_axes(expressions, epsilon=1e-12):
"""Drop running indices of ``Indexed(Literal(...))`` along axes on which the
literal is constant, exposing the underlying low-rank tabulation structure
to sum factorisation.

Only indices that are anchored elsewhere (see :func:`_anchored_indices`) are
dropped, so the rewrite never leaves a dangling index behind.

:arg expressions: iterable of GEM expressions
:arg epsilon: tolerance for recognising a constant axis
"""
expressions = list(expressions)
mapper = Memoizer(_drop_constant_axes)
mapper.epsilon = epsilon
mapper.anchored = _anchored_indices(expressions, epsilon)
return [mapper(e) for e in expressions]


@singledispatch
def _identity_to_delta(node, self):
raise AssertionError("cannot handle type %s" % type(node))


_identity_to_delta.register(Node)(reuse_if_untouched)


@_identity_to_delta.register(Indexed)
def _identity_to_delta_indexed(node, self):
child, = node.children
# A collocated tabulation matrix (CG nodal values at the collocated
# quadrature points) is the identity. Indexing it with two distinct running
# indices is exactly a Kronecker delta; rewriting it as such lets sum
# factorisation/delta elimination cancel the redundant contraction.
if (isinstance(child, Literal) and len(node.multiindex) == 2
and all(isinstance(i, Index) for i in node.multiindex)
and _is_identity_table(child.array, self.epsilon)):
i, j = node.multiindex
return Delta(i, j)
return reuse_if_untouched(node, self)


def convert_identity_literals(expressions, epsilon=1e-12):
"""Rewrite ``Indexed(Literal(I), (i, j))`` as ``Delta(i, j)`` for identity
tabulation matrices, exposing collocation structure to sum factorisation.
Comment on lines +190 to +191

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think we need to be adding a translator from an identity Literal into Delta, since by design we should not have generated the Literal in the first place.

GLL elements should tabulate to a gem.Delta. If this is not the case, then either the GLL rule is not being properly constructed/detected.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This fix should go in finat (under the FIAT repo). Ask Claude to attempt to reproduce an MFE by tabulating a hexahedral GLL element on a GLL quadrature, this should give a gem.Delta.


:arg expressions: iterable of GEM expressions
:arg epsilon: tolerance for recognising an identity matrix
"""
mapper = Memoizer(_identity_to_delta)
mapper.epsilon = epsilon
return [mapper(e) for e in expressions]


def Integrals(expressions, quadrature_multiindex, argument_multiindices, parameters):
"""Constructs an integral representation for each GEM integrand
expression.
Expand All @@ -34,6 +214,17 @@ def Integrals(expressions, quadrature_multiindex, argument_multiindices, paramet
# Rewrite: a / b => a * (1 / b)
expressions = replace_division(expressions)

# Expose the sum-factorisable structure of tensor-product tabulation
# matrices. First drop running indices on axes where a tabulation literal
# is merely a constant broadcast (a spurious coupling to the other
# quadrature directions); this uncovers the genuine 1D factors. Then
# rewrite any resulting identity tabulation (collocated nodal values) as a
# Kronecker delta so the delta elimination below cancels the redundant
# interpolation contraction. Together these turn the O(p^{2d}) collocated
# operator application into the O(p^{d+1}) sum-factorised form.
expressions = drop_constant_literal_axes(expressions)
expressions = convert_identity_literals(expressions)

# Unroll
max_extent = parameters["unroll_indexsum"]
if max_extent:
Expand Down Expand Up @@ -120,7 +311,10 @@ def group_key(pair):
yield (variable, expression)


finalise_options = dict(replace_delta=False)
# Lower any Deltas that survive sum factorisation (e.g. test-function
# collocation deltas introduced by convert_identity_literals that could not be
# cancelled against a sum index) to identity indexing for code generation.
finalise_options = dict(replace_delta=True)


def classify(argument_indices, expression, delta_inside):
Expand Down
Loading