diff --git a/HISTORY.rst b/HISTORY.rst
index 26a8f117..431fc987 100644
--- a/HISTORY.rst
+++ b/HISTORY.rst
@@ -4,6 +4,7 @@ History
 X.Y.Z (YYYY-MM-DD)
 ------------------
+* Add experimental fragments functionality (:pr:`282`)
 * Run CI weekly on Monday @ 2h30 am UTC (:pr:`288`)
 * Update minio server and client versions (:pr:`287`)
 * Retain ROWID coordinates during MS conversion (:pr:`286`)
diff --git a/daskms/apps/fragments.py b/daskms/apps/fragments.py
new file mode 100644
index 00000000..0166a30d
--- /dev/null
+++ b/daskms/apps/fragments.py
@@ -0,0 +1,64 @@
+import click
+import dask
+from daskms.fsspec_store import DaskMSStore
+from daskms.experimental.fragments import get_ancestry
+from daskms.experimental.zarr import xds_to_zarr, xds_from_zarr
+
+
+@click.group(help="Base command for interacting with fragments.")
+def fragments():
+    pass
+
+
+@click.command(help="List fragment and parents.")
+@click.argument(
+    "fragment_path",
+    type=DaskMSStore,
+)
+@click.option(
+    "-p/-np",
+    "--prune/--no-prune",
+    default=False,
+)
+def stat(fragment_path, prune):
+
+    ancestors = get_ancestry(fragment_path, only_required=prune)
+
+    click.echo("Ancestry:")
+
+    for i, fg in enumerate(ancestors):
+        if i == 0:
+            click.echo(f"  {fg.full_path} ---> root")
+        elif i == len(ancestors) - 1:
+            click.echo(f"  {fg.full_path} ---> target")
+        else:
+            click.echo(f"  {fg.full_path}")
+
+
+@click.command(help="Change fragment parent.")
+@click.argument(
+    "fragment_path",
+    type=DaskMSStore,
+)
+@click.argument(
+    "parent_path",
+    type=DaskMSStore,
+)
+def rebase(fragment_path, parent_path):
+    xdsl = xds_from_zarr(fragment_path, columns=[])
+
+    xdsl = [
+        xds.assign_attrs({"__dask_ms_parent_url__": parent_path.url}) for xds in xdsl
+    ]
+
+    writes = xds_to_zarr(xdsl, fragment_path)
+
+    dask.compute(writes)
+
+
+fragments.add_command(stat)
+fragments.add_command(rebase)
+
+
+def main():
+    fragments()
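
The command group above is exposed as the `fragments` console script (see the pyproject.toml change at the end of this diff). A sketch of the intended usage with hypothetical paths; the `stat` output below is paraphrased from the `click.echo` calls rather than captured from a real run:

    $ fragments stat /data/fragment.ms
    Ancestry:
      /data/root.ms ---> root
      /data/fragment.ms ---> target

    $ fragments rebase /data/fragment.ms /data/other_root.ms

`rebase` rewrites the fragment's `__dask_ms_parent_url__` attribute in place, after which the fragment resolves against the new parent.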
diff --git a/daskms/experimental/fragments/__init__.py b/daskms/experimental/fragments/__init__.py
new file mode 100644
index 00000000..b0d5bd15
--- /dev/null
+++ b/daskms/experimental/fragments/__init__.py
@@ -0,0 +1,237 @@
+from daskms import xds_from_storage_table
+from daskms.fsspec_store import DaskMSStore
+from daskms.utils import requires
+from daskms.experimental.zarr import xds_to_zarr
+from zarr.errors import GroupNotFoundError
+
+try:
+    import xarray  # noqa
+except ImportError as e:
+    xarray_import_error = e
+else:
+    xarray_import_error = None
+
+xarray_import_msg = "pip install dask-ms[xarray] for xarray support"
+
+
+def get_ancestry(store, only_required=True):
+    """Produces a list of stores needed to reconstruct the dataset at store."""
+
+    fragments = []
+
+    if not isinstance(store, DaskMSStore):
+        # TODO: Where, when and how should we pass storage options?
+        store = DaskMSStore(store)
+
+    while True:
+        root_store = DaskMSStore(store.root_url)
+
+        if store.exists():
+            try:
+                # Store exists and can be read.
+                xdsl = xds_from_storage_table(store, columns=[])
+                fragments += [store]
+            except GroupNotFoundError:
+                # Store exists, but cannot be read. We may be dealing with
+                # a subtable-only fragment. NOTE: This assumes that all
+                # subtables in a fragment have the same parent, so we don't
+                # care which subtable we read.
+                subtable_name = store.subdirectories()[0].rsplit("/")[-1]
+                subtable_store = store.subtable_store(subtable_name)
+                xdsl = xds_from_storage_table(subtable_store, columns=[])
+                fragments += [] if only_required else [subtable_store]
+        elif root_store.exists():
+            # Root store exists and can be read.
+            xdsl = xds_from_storage_table(root_store, columns=[])
+            fragments += [] if only_required else [root_store]
+        else:
+            raise FileNotFoundError(f"No root/fragment found at {store}.")
+
+        subtable = store.table
+
+        parent_urls = {xds.attrs.get("__dask_ms_parent_url__", None) for xds in xdsl}
+
+        assert (
+            len(parent_urls) == 1
+        ), "Fragment has more than one parent - this is not supported."
+
+        parent_url = parent_urls.pop()
+
+        if parent_url:
+            if not isinstance(parent_url, DaskMSStore):
+                # TODO: Where, when and how should we pass storage options?
+                store = DaskMSStore(parent_url).subtable_store(subtable or "")
+        else:
+            if store.table and not any(f.table for f in fragments):
+                # If we are attempting to open a subtable, we don't know if
+                # it exists until we have traversed the entire ancestry.
+                raise FileNotFoundError(
+                    f"{store.table} subtable was not found in parents."
+                )
+
+            return fragments[::-1]  # Flip so that the root appears first.
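
For orientation, a minimal sketch of calling `get_ancestry` from Python, assuming a hypothetical chain in which `fragment.ms` names `root.ms` as its parent via `__dask_ms_parent_url__`. A plain string argument is wrapped in a `DaskMSStore`, and the result is ordered root first, target last:

    from daskms.experimental.fragments import get_ancestry

    stores = get_ancestry("/data/fragment.ms")
    print([s.url for s in stores])
    # Expected shape of the output, e.g.:
    # ['file:///data/root.ms', 'file:///data/fragment.ms']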
+
+
+@requires(xarray_import_msg, xarray_import_error)
+def consolidate(xdsl):
+    """
+    Consolidates a list of xarray datasets by assigning data variables.
+    Priority is determined by the position within the list, with elements at
+    the end of the list having higher priority than those at the start. The
+    primary purpose of this function is the construction of a consolidated
+    dataset from a root and deltas (fragments).
+
+    Parameters
+    ----------
+    xdsl : tuple or list
+        Tuple or list of :class:`xarray.Dataset` objects to consolidate.
+
+    Returns
+    -------
+    consolidated_xds : :class:`xarray.Dataset`
+        A single :class:`xarray.Dataset`.
+    """
+
+    root_xds = xdsl[0]  # First element is the root for this operation.
+
+    root_schema = root_xds.__daskms_partition_schema__
+    root_partition_keys = {p[0] for p in root_schema}
+
+    consolidated_xds = root_xds  # Will be replaced in the loop.
+
+    for xds in xdsl[1:]:
+        xds_schema = xds.__daskms_partition_schema__
+        xds_partition_keys = {p[0] for p in xds_schema}
+
+        if root_partition_keys.symmetric_difference(xds_partition_keys):
+            raise ValueError(
+                f"consolidate failed due to conflicting partition keys. "
+                f"This usually means the partition keys of the fragments "
+                f"are inconsistent with the current group_cols argument. "
+                f"Current group_cols produces {root_partition_keys} but "
+                f"the fragment has {xds_partition_keys}."
+            )
+
+        consolidated_xds = consolidated_xds.assign(xds.data_vars)
+
+    return consolidated_xds
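
A small self-contained sketch of the priority rule in `consolidate`, using hand-built datasets. The `(key, dtype)` pair format assumed for `__daskms_partition_schema__` is inferred from how the function indexes the schema, not taken from the source:

    import numpy as np
    import xarray as xr
    from daskms.experimental.fragments import consolidate

    schema = (("DATA_DESC_ID", "i4"),)  # Assumed (key, dtype) pair format.
    root = xr.Dataset(
        {"DATA": ("row", np.zeros(4)), "UVW": ("row", np.ones(4))},
        attrs={"__daskms_partition_schema__": schema},
    )
    delta = xr.Dataset(
        {"DATA": ("row", np.full(4, 7.0))},
        attrs={"__daskms_partition_schema__": schema},
    )

    merged = consolidate([root, delta])
    assert (merged.DATA.values == 7).all()  # Later datasets win for DATA.
    assert (merged.UVW.values == 1).all()  # UVW falls through from the root.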
" + ) + + xds = [x.assign_attrs({"__dask_ms_parent_url__": parent.url}) for x in xds] + + return xds_to_zarr(xds, store, **kwargs) diff --git a/daskms/experimental/fragments/tests/__init__.py b/daskms/experimental/fragments/tests/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/daskms/experimental/fragments/tests/test_fragments.py b/daskms/experimental/fragments/tests/test_fragments.py new file mode 100644 index 00000000..557f08fc --- /dev/null +++ b/daskms/experimental/fragments/tests/test_fragments.py @@ -0,0 +1,346 @@ +import pytest +import dask +import dask.array as da +import numpy.testing as npt +from daskms import xds_from_storage_ms, xds_from_storage_table +from daskms.experimental.fragments import xds_to_table_fragment, xds_from_table_fragment + +# Prevent warning pollution generated by all calls to xds_from_zarr with +# unsupported kwargs. +pytestmark = pytest.mark.filterwarnings( + "ignore:The following unsupported kwargs were ignored in xds_from_zarr" +) + + +@pytest.fixture( + scope="module", + params=[ + ("DATA_DESC_ID", "FIELD_ID", "SCAN_NUMBER"), + ("DATA_DESC_ID", "FIELD_ID"), + ("DATA_DESC_ID",), + ], +) +def group_cols(request): + return request.param + + +# -----------------------------MAIN_TABLE_TESTS-------------------------------- + + +def test_fragment_with_noop(ms, tmp_path_factory, group_cols): + """Unchanged data_vars must remain the same when read from a fragment.""" + + reads = xds_from_storage_ms( + ms, + index_cols=("TIME",), + group_cols=group_cols, + ) + + tmp_dir = tmp_path_factory.mktemp("fragments") + fragment_path = tmp_dir / "fragment.ms" + + writes = xds_to_table_fragment(reads, fragment_path, ms, columns=("DATA",)) + + dask.compute(writes) + + fragment_reads = xds_from_table_fragment( + fragment_path, + index_cols=("TIME",), + group_cols=group_cols, + ) + + for rxds, frxds in zip(reads, fragment_reads): + assert rxds.equals(frxds), "Datasets not identical." 
diff --git a/daskms/experimental/fragments/tests/__init__.py b/daskms/experimental/fragments/tests/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/daskms/experimental/fragments/tests/test_fragments.py b/daskms/experimental/fragments/tests/test_fragments.py
new file mode 100644
index 00000000..557f08fc
--- /dev/null
+++ b/daskms/experimental/fragments/tests/test_fragments.py
@@ -0,0 +1,346 @@
+import pytest
+import dask
+import dask.array as da
+import numpy.testing as npt
+from daskms import xds_from_storage_ms, xds_from_storage_table
+from daskms.experimental.fragments import xds_to_table_fragment, xds_from_table_fragment
+
+# Prevent warning pollution generated by all calls to xds_from_zarr with
+# unsupported kwargs.
+pytestmark = pytest.mark.filterwarnings(
+    "ignore:The following unsupported kwargs were ignored in xds_from_zarr"
+)
+
+
+@pytest.fixture(
+    scope="module",
+    params=[
+        ("DATA_DESC_ID", "FIELD_ID", "SCAN_NUMBER"),
+        ("DATA_DESC_ID", "FIELD_ID"),
+        ("DATA_DESC_ID",),
+    ],
+)
+def group_cols(request):
+    return request.param
+
+
+# -----------------------------MAIN_TABLE_TESTS--------------------------------
+
+
+def test_fragment_with_noop(ms, tmp_path_factory, group_cols):
+    """Unchanged data_vars must remain the same when read from a fragment."""
+
+    reads = xds_from_storage_ms(
+        ms,
+        index_cols=("TIME",),
+        group_cols=group_cols,
+    )
+
+    tmp_dir = tmp_path_factory.mktemp("fragments")
+    fragment_path = tmp_dir / "fragment.ms"
+
+    writes = xds_to_table_fragment(reads, fragment_path, ms, columns=("DATA",))
+
+    dask.compute(writes)
+
+    fragment_reads = xds_from_table_fragment(
+        fragment_path,
+        index_cols=("TIME",),
+        group_cols=group_cols,
+    )
+
+    for rxds, frxds in zip(reads, fragment_reads):
+        assert rxds.equals(frxds), "Datasets not identical."
+
+
+def test_fragment_with_update(ms, tmp_path_factory, group_cols):
+    """Updated data_vars must change when read from a fragment."""
+
+    reads = xds_from_storage_ms(
+        ms,
+        index_cols=("TIME",),
+        group_cols=group_cols,
+    )
+
+    tmp_dir = tmp_path_factory.mktemp("fragments")
+    fragment_path = tmp_dir / "fragment.ms"
+
+    updates = [
+        xds.assign({"DATA": (xds.DATA.dims, da.ones_like(xds.DATA.data))})
+        for xds in reads
+    ]
+
+    writes = xds_to_table_fragment(updates, fragment_path, ms, columns=("DATA",))
+
+    dask.compute(writes)
+
+    fragment_reads = xds_from_table_fragment(
+        fragment_path,
+        index_cols=("TIME",),
+        group_cols=group_cols,
+    )
+
+    for frxds in fragment_reads:
+        npt.assert_array_equal(1, frxds.DATA.data)
+
+
+def test_nonoverlapping_parents(ms, tmp_path_factory, group_cols):
+    """All updated data_vars must change when read from a fragment."""
+
+    reads = xds_from_storage_ms(
+        ms,
+        index_cols=("TIME",),
+        group_cols=group_cols,
+    )
+
+    tmp_dir = tmp_path_factory.mktemp("fragments")
+    fragment0_path = tmp_dir / "fragment0.ms"
+    fragment1_path = tmp_dir / "fragment1.ms"
+
+    updates = [
+        xds.assign({"DATA": (xds.DATA.dims, da.zeros_like(xds.DATA.data))})
+        for xds in reads
+    ]
+
+    writes = xds_to_table_fragment(updates, fragment0_path, ms, columns=("DATA",))
+
+    dask.compute(writes)
+
+    fragment0_reads = xds_from_table_fragment(
+        fragment0_path,
+        index_cols=("TIME",),
+        group_cols=group_cols,
+    )
+
+    updates = [
+        xds.assign({"UVW": (xds.UVW.dims, da.zeros_like(xds.UVW.data))})
+        for xds in fragment0_reads
+    ]
+
+    writes = xds_to_table_fragment(
+        updates, fragment1_path, fragment0_path, columns=("UVW",)
+    )
+
+    dask.compute(writes)
+
+    fragment1_reads = xds_from_table_fragment(
+        fragment1_path,
+        index_cols=("TIME",),
+        group_cols=group_cols,
+    )
+
+    for frxds in fragment1_reads:
+        npt.assert_array_equal(0, frxds.DATA.data)
+        npt.assert_array_equal(0, frxds.UVW.data)
+
+
+def test_overlapping_parents(ms, tmp_path_factory, group_cols):
+    """Youngest child takes priority if updated data_vars overlap."""
+    reads = xds_from_storage_ms(
+        ms,
+        index_cols=("TIME",),
+        group_cols=group_cols,
+    )
+
+    tmp_dir = tmp_path_factory.mktemp("fragments")
+    fragment0_path = tmp_dir / "fragment0.ms"
+    fragment1_path = tmp_dir / "fragment1.ms"
+
+    updates = [
+        xds.assign({"DATA": (xds.DATA.dims, da.ones_like(xds.DATA.data))})
+        for xds in reads
+    ]
+
+    writes = xds_to_table_fragment(updates, fragment0_path, ms, columns=("DATA",))
+
+    dask.compute(writes)
+
+    fragment0_reads = xds_from_table_fragment(
+        fragment0_path,
+        index_cols=("TIME",),
+        group_cols=group_cols,
+    )
+
+    updates = [
+        xds.assign({"DATA": (xds.DATA.dims, da.zeros_like(xds.DATA.data))})
+        for xds in fragment0_reads
+    ]
+
+    writes = xds_to_table_fragment(
+        updates, fragment1_path, fragment0_path, columns=("DATA",)
+    )
+
+    dask.compute(writes)
+
+    fragment1_reads = xds_from_table_fragment(
+        fragment1_path,
+        index_cols=("TIME",),
+        group_cols=group_cols,
+    )
+
+    for frxds in fragment1_reads:
+        npt.assert_array_equal(0, frxds.DATA.data)
+
+
+def test_inconsistent_partitioning(ms, tmp_path_factory, group_cols):
+    """Raises a ValueError when partitioning would be inconsistent."""
+    reads = xds_from_storage_ms(
+        ms,
+        index_cols=("TIME",),
+        group_cols=group_cols,
+    )
+
+    tmp_dir = tmp_path_factory.mktemp("fragments")
+    fragment_path = tmp_dir / "fragment.ms"
+
+    writes = xds_to_table_fragment(reads, fragment_path, ms, columns=("DATA",))
+
+    dask.compute(writes)
+
+    with pytest.raises(ValueError, match="consolidate failed"):
+        xds_from_table_fragment(
+            fragment_path,
+            index_cols=("TIME",),
+            group_cols=(),
+        )
+
+
+def test_mutate_parent(ms, tmp_path_factory):
+    """Raises a ValueError when a fragment would be its own parent."""
+    reads = xds_from_storage_ms(
+        ms,
+        index_cols=("TIME",),
+        group_cols=("DATA_DESC_ID", "FIELD_ID", "SCAN_NUMBER"),
+    )
+
+    with pytest.raises(ValueError, match="store and parent arguments"):
+        xds_to_table_fragment(reads, ms, ms, columns=("DATA",))
+
+
+def test_missing_parent(ms, tmp_path_factory):
+    """Raises a FileNotFoundError when a fragment is missing a parent."""
+
+    reads = xds_from_storage_ms(
+        ms,
+        index_cols=("TIME",),
+        group_cols=("DATA_DESC_ID", "FIELD_ID", "SCAN_NUMBER"),
+    )
+
+    tmp_dir = tmp_path_factory.mktemp("fragments")
+    fragment_path = tmp_dir / "fragment.ms"
+    missing_parent = tmp_dir / "missing.ms"
+
+    writes = xds_to_table_fragment(
+        reads, fragment_path, missing_parent, columns=("DATA",)
+    )
+
+    dask.compute(writes)
+
+    with pytest.raises(FileNotFoundError, match="No root/fragment found at"):
+        xds_from_table_fragment(
+            fragment_path,
+            index_cols=("TIME",),
+            group_cols=("DATA_DESC_ID", "FIELD_ID", "SCAN_NUMBER"),
+        )
+
+
+def test_datavar_in_parent(ms, tmp_path_factory, group_cols):
+    """Datavars not present in the fragment must be read from the parent."""
+
+    reads = xds_from_storage_ms(
+        ms,
+        index_cols=("TIME",),
+        group_cols=group_cols,
+    )
+
+    tmp_dir = tmp_path_factory.mktemp("fragments")
+    fragment_path = tmp_dir / "fragment.ms"
+
+    writes = xds_to_table_fragment(reads, fragment_path, ms, columns=("DATA",))
+
+    dask.compute(writes)
+
+    fragment_reads = xds_from_table_fragment(
+        fragment_path,
+        columns=("UVW",),  # Not in fragment.
+        index_cols=("TIME",),
+        group_cols=group_cols,
+    )
+
+    for rxds, frxds in zip(reads, fragment_reads):
+        npt.assert_array_equal(rxds.UVW.data, frxds.UVW.data)
+
+
+# ------------------------------SUBTABLE_TESTS---------------------------------
+
+
+def test_subtable_fragment_with_noop(spw_table, tmp_path_factory):
+    """Unchanged data_vars must remain the same when read from a fragment."""
+
+    reads = xds_from_storage_table(spw_table, group_cols=("__row__",))
+
+    tmp_dir = tmp_path_factory.mktemp("fragments")
+    fragment_path = tmp_dir / "fragment.ms"
+
+    writes = xds_to_table_fragment(
+        reads, fragment_path, spw_table, columns=("CHAN_FREQ",)
+    )
+
+    dask.compute(writes)
+
+    fragment_reads = xds_from_table_fragment(fragment_path, group_cols=("__row__",))
+
+    for rxds, frxds in zip(reads, fragment_reads):
+        assert rxds.equals(frxds), "Datasets not identical."
+
+
+def test_subtable_fragment_with_update(spw_table, tmp_path_factory):
+    """Updated data_vars must change when read from a fragment."""
+
+    reads = xds_from_storage_table(spw_table, group_cols=("__row__",))
+
+    tmp_dir = tmp_path_factory.mktemp("fragments")
+    fragment_path = tmp_dir / "fragment.ms"
+
+    updates = [
+        xds.assign(
+            {"CHAN_FREQ": (xds.CHAN_FREQ.dims, da.ones_like(xds.CHAN_FREQ.data))}
+        )
+        for xds in reads
+    ]
+
+    writes = xds_to_table_fragment(
+        updates, fragment_path, spw_table, columns=("CHAN_FREQ",)
+    )
+
+    dask.compute(writes)
+
+    fragment_reads = xds_from_table_fragment(fragment_path, group_cols=("__row__",))
+
+    for frxds in fragment_reads:
+        npt.assert_array_equal(1, frxds.CHAN_FREQ.data)
+
+
+def test_subtable_datavar_in_parent(spw_table, tmp_path_factory):
+    """Datavars not present in the fragment must be read from the parent."""
+
+    reads = xds_from_storage_table(spw_table, group_cols=("__row__",))
+
+    tmp_dir = tmp_path_factory.mktemp("fragments")
+    fragment_path = tmp_dir / "fragment.ms"
+
+    writes = xds_to_table_fragment(
+        reads, fragment_path, spw_table, columns=("CHAN_FREQ",)
+    )
+
+    dask.compute(writes)
+
+    fragment_reads = xds_from_table_fragment(
+        fragment_path, columns=("NUM_CHAN",), group_cols=("__row__",)
+    )
+
+    for rxds, frxds in zip(reads, fragment_reads):
+        npt.assert_array_equal(rxds.NUM_CHAN.data, frxds.NUM_CHAN.data)
+
+
+# -----------------------------------------------------------------------------
diff --git a/daskms/fsspec_store.py b/daskms/fsspec_store.py
index 56acf2a9..2c35bdc6 100644
--- a/daskms/fsspec_store.py
+++ b/daskms/fsspec_store.py
@@ -94,6 +94,10 @@ def assert_type(self, store_type):
     def url(self):
         return f"{self.fs.unstrip_protocol(self.canonical_path)}"
 
+    @property
+    def root_url(self):
+        return f"{self.fs.unstrip_protocol(self.root)}"
+
     def subdirectories(self):
         return [
             d["name"]
diff --git a/pyproject.toml b/pyproject.toml
index 9bdff572..abacdaf8 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -22,6 +22,7 @@ pytest = {version = "^7.1.3", optional=true}
 
 [tool.poetry.scripts]
 dask-ms = "daskms.apps.entrypoint:main"
+fragments = "daskms.apps.fragments:main"
 
 [tool.poetry.extras]
 arrow = ["pyarrow"]
@@ -29,7 +30,7 @@ xarray = ["xarray"]
 zarr = ["zarr"]
 s3 = ["s3fs"]
 complete = ["s3fs", "pyarrow", "xarray", "zarr"]
-testing = ["minio", "pytest"]
+testing = ["minio", "pytest", "xarray"]
 
 [tool.poetry.group.dev.dependencies]
 tbump = "^6.9.0"
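
A note on the new `root_url` property: `get_ancestry` uses it to recover the main-table store from a possibly subtable-scoped store before falling back to the root. A hedged sketch of the expected behaviour, assuming the `::` subtable suffix convention for store URLs and a hypothetical local path:

    from daskms.fsspec_store import DaskMSStore

    store = DaskMSStore("file:///data/fragment.ms::SPECTRAL_WINDOW")
    # url addresses the subtable itself, while root_url should strip the
    # subtable component, e.g. 'file:///data/fragment.ms'.
    print(store.url, store.root_url)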