diff --git a/src/spatialdata/_core/spatialdata.py b/src/spatialdata/_core/spatialdata.py index fb55ab08..e3c00650 100644 --- a/src/spatialdata/_core/spatialdata.py +++ b/src/spatialdata/_core/spatialdata.py @@ -30,6 +30,7 @@ validate_table_attr_keys, ) from spatialdata._logging import logger +from spatialdata._store import PathLike, normalize_path, open_read_store, open_write_store, path_from_store from spatialdata._types import ArrayLike, Raster_T from spatialdata._utils import _deprecation_alias from spatialdata.models import ( @@ -121,7 +122,7 @@ def __init__( tables: dict[str, AnnData] | Tables | None = None, attrs: Mapping[Any, Any] | None = None, ) -> None: - self._path: Path | None = None + self._path: Path | UPath | None = None self._shared_keys: set[str | None] = set() self._images: Images = Images(shared_keys=self._shared_keys) @@ -548,16 +549,18 @@ def is_backed(self) -> bool: return self.path is not None @property - def path(self) -> Path | None: - """Path to the Zarr storage.""" + def path(self) -> Path | UPath | None: + """Path to the Zarr storage (always :class:`pathlib.Path` or :class:`upath.UPath` when set).""" return self._path @path.setter - def path(self, value: Path | None) -> None: - if value is None or isinstance(value, str | Path): - self._path = value - else: - raise TypeError("Path must be `None`, a `str` or a `Path` object.") + def path(self, value: str | Path | UPath | None) -> None: + self._path = None if value is None else normalize_path(value) + + def _require_path(self) -> PathLike: + if self._path is None: + raise ValueError("The SpatialData object is not backed by a Zarr store.") + return self._path def locate_element(self, element: SpatialElement) -> list[str]: """ @@ -987,13 +990,7 @@ def elements_paths_on_disk(self) -> list[str]: ------- A list of paths of the elements saved in the Zarr store. """ - from spatialdata._io._utils import _resolve_zarr_store - - if self.path is None: - raise ValueError("The SpatialData object is not backed by a Zarr store.") - - store = _resolve_zarr_store(self.path) - root = zarr.open_group(store=store, mode="r") + zarr_store = self._require_path() elements_in_zarr = [] def find_groups(obj: zarr.Group, path: str) -> None: @@ -1002,13 +999,14 @@ def find_groups(obj: zarr.Group, path: str) -> None: if isinstance(obj, zarr.Group) and path.count("/") == 1: elements_in_zarr.append(path) - for element_type in root: - if element_type in ["images", "labels", "points", "shapes", "tables"]: - for element_name in root[element_type]: - path = f"{element_type}/{element_name}" - elements_in_zarr.append(path) + with open_read_store(zarr_store) as store: + root = zarr.open_group(store=store, mode="r") + for element_type in root: + if element_type in ["images", "labels", "points", "shapes", "tables"]: + for element_name in root[element_type]: + path = f"{element_type}/{element_name}" + elements_in_zarr.append(path) # root.visit(lambda path: find_groups(root[path], path)) - store.close() return elements_in_zarr def _symmetric_difference_with_zarr_store(self) -> tuple[list[str], list[str]]: @@ -1037,18 +1035,56 @@ def _symmetric_difference_with_zarr_store(self) -> tuple[list[str], list[str]]: def _validate_can_safely_write_to_path( self, - file_path: str | Path, + file_path: str | Path | UPath, overwrite: bool = False, saving_an_element: bool = False, ) -> None: - from spatialdata._io._utils import _backed_elements_contained_in_path, _is_subfolder, _resolve_zarr_store + """ + Guard against unsafe writes for **local** paths (zarr check, Dask backing, subfolders). + + For :class:`upath.UPath`, ``overwrite=False`` is rejected: we cannot reliably check + whether a remote store already exists (fsspec existence semantics vary by backend and + object stores have no directory concept), so the "fail if exists" contract cannot be + honored. Callers must pass ``overwrite=True`` to explicitly acknowledge that the write + may clobber pre-existing data at the target. + """ + from upath.implementations.local import PosixUPath, WindowsUPath - if isinstance(file_path, str): + from spatialdata._io._utils import ( + _backed_elements_contained_in_path, + _is_subfolder, + _resolve_zarr_store, + ) + + # Hierarchical URIs ("scheme://...") must become UPath: plain Path(str) breaks cloud URLs + # (S3-compatible stores, Azure abfs:// / az://, GCS gs://, https://, fsspec chains, etc.). + if isinstance(file_path, str) and "://" in file_path: + file_path = UPath(file_path) + elif isinstance(file_path, str): file_path = Path(file_path) - if not isinstance(file_path, Path): - raise ValueError(f"file_path must be a string or a Path object, type(file_path) = {type(file_path)}.") + if not isinstance(file_path, (Path, UPath)): + raise ValueError(f"file_path must be a `str`, `Path`, or `UPath` object, got {type(file_path).__name__}.") + + # Local UPath variants (PosixUPath / WindowsUPath) wrap a plain filesystem path; they + # have reliable existence semantics and must go through the same local validation as + # Path. Only *remote* UPath (cloud / http / memory / etc.) falls through the remote guard. + is_remote_upath = isinstance(file_path, UPath) and not isinstance(file_path, (PosixUPath, WindowsUPath)) + + if is_remote_upath: + # The overwrite opt-in only applies at the top-level store entry. Per-element writes + # issued internally by ``write()`` (and incremental ``write_element`` calls into an + # existing store) must not re-trigger the guard on every sub-key, or writing to a + # remote target would be impossible. + if not overwrite and not saving_an_element: + raise NotImplementedError( + "Writing to a remote (UPath) target requires overwrite=True. " + "We cannot reliably check whether the remote store already exists, so the write " + "may clobber existing data; pass overwrite=True to acknowledge this." + ) + return + # Local Path: existing logic # TODO: add test for this if os.path.exists(file_path): store = _resolve_zarr_store(file_path) @@ -1077,8 +1113,13 @@ def _validate_can_safely_write_to_path( ERROR_MSG + "\nDetails: the target path contains one or more files that Dask use for " "backing elements in the SpatialData object." + WORKAROUND ) - if self.path is not None and ( - _is_subfolder(parent=self.path, child=file_path) or _is_subfolder(parent=file_path, child=self.path) + # Subfolder checks only for local paths (Path); skip when self.path is UPath + if ( + self.path is not None + and isinstance(self.path, Path) + and ( + _is_subfolder(parent=self.path, child=file_path) or _is_subfolder(parent=file_path, child=self.path) + ) ): if saving_an_element and _is_subfolder(parent=self.path, child=file_path): raise ValueError( @@ -1107,7 +1148,7 @@ def _validate_all_elements(self) -> None: @_deprecation_alias(format="sdata_formats", version="0.7.0") def write( self, - file_path: str | Path, + file_path: str | Path | zarr.storage.StoreLike, overwrite: bool = False, consolidate_metadata: bool = True, update_sdata_path: bool = True, @@ -1121,7 +1162,14 @@ def write( Parameters ---------- file_path - The path to the Zarr store to write to. + Where to write. One of: + + - A local filesystem path (``str`` or :class:`pathlib.Path`) + - A zarr store (e.g. :class:`zarr.storage.LocalStore`, + :class:`zarr.storage.FsspecStore`) carrying its own filesystem (and + credentials) — the supported form for remote backends like + S3 / Azure / GCS. Stores without a filesystem path (e.g. + :class:`zarr.storage.MemoryStore`) are not currently supported. overwrite If `True`, overwrite the Zarr store if it already exists. If `False`, `write()` will fail if the Zarr store already exists. @@ -1167,22 +1215,34 @@ def write( supported. If not specified, the compression will be `lz4` with compression level 5. Bytes are automatically ordered for more efficient compression. """ - from spatialdata._io._utils import _resolve_zarr_store, _validate_compressor_args + from spatialdata._io._utils import _validate_compressor_args from spatialdata._io.format import _parse_formats parsed = _parse_formats(sdata_formats) _validate_compressor_args(raster_compressor) - if isinstance(file_path, str): - file_path = Path(file_path) + # Resolve all input forms (str / Path / StoreLike) to a path the internal per-element + # write machinery can use. For zarr stores, derive a backing path via path_from_store; + # stores without a filesystem path (e.g. MemoryStore) are rejected here because the + # per-element machinery currently re-opens stores from the path. + if isinstance(file_path, zarr.abc.store.Store): + derived = path_from_store(file_path) + if derived is None: + raise NotImplementedError( + f"Writing to a store of type {type(file_path).__name__} is not supported " + "because it does not expose a filesystem path. Pass a LocalStore, FsspecStore, " + "or a path/UPath instead." + ) + file_path = derived + else: + file_path = normalize_path(file_path) self._validate_can_safely_write_to_path(file_path, overwrite=overwrite) self._validate_all_elements() - store = _resolve_zarr_store(file_path) - zarr_format = parsed["SpatialData"].zarr_format - zarr_group = zarr.create_group(store=store, overwrite=overwrite, zarr_format=zarr_format) - self.write_attrs(zarr_group=zarr_group, sdata_format=parsed["SpatialData"]) - store.close() + with open_write_store(file_path) as store: + zarr_format = parsed["SpatialData"].zarr_format + zarr_group = zarr.create_group(store=store, overwrite=overwrite, zarr_format=zarr_format) + self.write_attrs(zarr_group=zarr_group, sdata_format=parsed["SpatialData"]) for element_type, element_name, element in self.gen_elements(): self._write_element( @@ -1197,7 +1257,7 @@ def write( ) if self.path != file_path and update_sdata_path: - self.path = file_path + self._path = file_path if consolidate_metadata: self.write_consolidated_metadata() @@ -1205,7 +1265,7 @@ def write( def _write_element( self, element: SpatialElement | AnnData, - zarr_container_path: Path, + zarr_container_path: Path | UPath, element_type: str, element_name: str, overwrite: bool, @@ -1215,9 +1275,9 @@ def _write_element( ) -> None: from spatialdata._io.io_zarr import _get_groups_for_element - if not isinstance(zarr_container_path, Path): + if not isinstance(zarr_container_path, (Path, UPath)): raise ValueError( - f"zarr_container_path must be a Path object, type(zarr_container_path) = {type(zarr_container_path)}." + f"zarr_container_path must be a `Path` or `UPath` object, got {type(zarr_container_path).__name__}." ) file_path_of_element = zarr_container_path / element_type / element_name self._validate_can_safely_write_to_path( @@ -1452,13 +1512,12 @@ def delete_element_from_disk(self, element_name: str | list[str]) -> None: "more elements in the SpatialData object. Deleting the data would corrupt the SpatialData object." ) - from spatialdata._io._utils import _resolve_zarr_store + zarr_store = self._require_path() # delete the element - store = _resolve_zarr_store(self.path) - root = zarr.open_group(store=store, mode="r+", use_consolidated=False) - del root[element_type][element_name] - store.close() + with open_write_store(zarr_store) as store: + root = zarr.open_group(store=store, mode="r+", use_consolidated=False) + del root[element_type][element_name] if self.has_consolidated_metadata(): self.write_consolidated_metadata() @@ -1481,14 +1540,11 @@ def write_consolidated_metadata(self) -> None: _write_consolidated_metadata(self.path) def has_consolidated_metadata(self) -> bool: - from spatialdata._io._utils import _resolve_zarr_store - return_value = False - store = _resolve_zarr_store(self.path) - group = zarr.open_group(store, mode="r") - if getattr(group.metadata, "consolidated_metadata", None): - return_value = True - store.close() + with open_read_store(self._require_path()) as store: + group = zarr.open_group(store, mode="r") + if getattr(group.metadata, "consolidated_metadata", None): + return_value = True return return_value def _validate_can_write_metadata_on_element(self, element_name: str) -> tuple[str, SpatialElement | AnnData] | None: @@ -1518,7 +1574,7 @@ def _validate_can_write_metadata_on_element(self, element_name: str) -> tuple[st # check if the element exists in the Zarr storage if not _group_for_element_exists( - zarr_path=Path(self.path), + zarr_path=self.path, element_type=element_type, element_name=element_name, ): @@ -1532,7 +1588,7 @@ def _validate_can_write_metadata_on_element(self, element_name: str) -> tuple[st # warn the users if the element is not self-contained, that is, it is Dask-backed by files outside the Zarr # group for the element - element_zarr_path = Path(self.path) / element_type / element_name + element_zarr_path = self.path / element_type / element_name if not _is_element_self_contained(element=element, element_path=element_zarr_path): logger.info( f"Element {element_type}/{element_name} is not self-contained. The metadata will be" @@ -1573,7 +1629,7 @@ def write_channel_names(self, element_name: str | None = None) -> None: # Mypy does not understand that path is not None so we have the check in the conditional if element_type == "images" and self.path is not None: _, _, element_group = _get_groups_for_element( - zarr_path=Path(self.path), element_type=element_type, element_name=element_name, use_consolidated=False + zarr_path=self.path, element_type=element_type, element_name=element_name, use_consolidated=False ) from spatialdata._io._utils import overwrite_channel_names @@ -1582,6 +1638,10 @@ def write_channel_names(self, element_name: str | None = None) -> None: else: raise ValueError(f"Can't set channel names for element of type '{element_type}'.") + # See ``write_transformations`` for why this refresh is needed after an in-place attrs write. + if self.has_consolidated_metadata(): + self.write_consolidated_metadata() + def write_transformations(self, element_name: str | None = None) -> None: """ Write transformations to disk for a single element, or for all elements, without rewriting the data. @@ -1617,7 +1677,7 @@ def write_transformations(self, element_name: str | None = None) -> None: # Mypy does not understand that path is not None so we have a conditional assert self.path is not None _, _, element_group = _get_groups_for_element( - zarr_path=Path(self.path), + zarr_path=self.path, element_type=element_type, element_name=element_name, use_consolidated=False, @@ -1646,6 +1706,12 @@ def write_transformations(self, element_name: str | None = None) -> None: else: raise ValueError(f"Unknown element type {type(element)}") + # Consolidated metadata caches every element's attrs at the root; an in-place attrs + # write to an element leaves that cache stale, so the next ``read_zarr`` would return + # the old transformation. Refresh if the store had consolidated metadata. + if self.has_consolidated_metadata(): + self.write_consolidated_metadata() + def _element_type_from_element_name(self, element_name: str) -> str: self._validate_element_names_are_unique() element = self.get(element_name) @@ -1674,18 +1740,17 @@ def write_attrs( sdata_format: SpatialDataContainerFormatType | None = None, zarr_group: zarr.Group | None = None, ) -> None: - from spatialdata._io._utils import _resolve_zarr_store from spatialdata._io.format import CurrentSpatialDataContainerFormat, SpatialDataContainerFormatType sdata_format = sdata_format if sdata_format is not None else CurrentSpatialDataContainerFormat() assert isinstance(sdata_format, SpatialDataContainerFormatType) - store = None - if zarr_group is None: assert self.is_backed(), "The SpatialData object must be backed by a Zarr store to write attrs." - store = _resolve_zarr_store(self.path) - zarr_group = zarr.open_group(store=store, mode="r+") + with open_write_store(self._require_path()) as store: + zarr_group = zarr.open_group(store=store, mode="r+") + self.write_attrs(sdata_format=sdata_format, zarr_group=zarr_group) + return version = sdata_format.spatialdata_format_version version_specific_attrs = sdata_format.attrs_to_dict() @@ -1696,9 +1761,6 @@ def write_attrs( except TypeError as e: raise TypeError("Invalid attribute in SpatialData.attrs") from e - if store is not None: - store.close() - def write_metadata( self, element_name: str | None = None, @@ -1985,7 +2047,8 @@ def h(s: str) -> str: descr = "SpatialData object" if self.path is not None: - descr += f", with associated Zarr store: {self.path.resolve()}" + path_descr = str(self.path) if isinstance(self.path, UPath) else self.path.resolve() + descr += f", with associated Zarr store: {path_descr}" non_empty_elements = self._non_empty_elements() last_element_index = len(non_empty_elements) - 1 diff --git a/src/spatialdata/_io/_utils.py b/src/spatialdata/_io/_utils.py index 3be56d67..904a9856 100644 --- a/src/spatialdata/_io/_utils.py +++ b/src/spatialdata/_io/_utils.py @@ -1,6 +1,7 @@ from __future__ import annotations import filecmp +import json import os.path import re import sys @@ -11,7 +12,7 @@ from contextlib import contextmanager from enum import Enum from functools import singledispatch -from pathlib import Path +from pathlib import Path, PurePosixPath from typing import Any, Literal import zarr @@ -27,7 +28,6 @@ from spatialdata._core.spatialdata import SpatialData from spatialdata._io.format import RasterFormatType, RasterFormatV01, RasterFormatV02, RasterFormatV03 -from spatialdata._logging import logger from spatialdata._utils import get_pyramid_levels from spatialdata.models._utils import ( MappingToCoordinateSystem_t, @@ -39,6 +39,12 @@ from spatialdata.transformations.transformations import BaseTransformation, _get_current_output_axes +def join_fsspec_store_path(store_path: str, relative_path: str) -> str: + """Append a relative zarr-group path to an FsspecStore root, yielding a fsspec key.""" + rel = relative_path.lstrip("/") + return str(PurePosixPath(store_path) / rel) if rel else store_path + + def _get_transformations_from_ngff_dict( list_of_encoded_ngff_transformations: list[dict[str, Any]], ) -> MappingToCoordinateSystem_t: @@ -318,6 +324,83 @@ def _find_piece_dict(obj: dict[str, tuple[str | None]] | Task) -> dict[str, tupl return None +def _extract_parquet_paths_from_task(obj: Any) -> list[str]: + """Recursively extract parquet file paths from a dask ``read_parquet`` task. + + Dask's task-graph shape changed between the version pinned before scverse/spatialdata + PR #1006 (https://github.com/scverse/spatialdata/pull/1006 "unpinning dask", commit + 53b9438a https://github.com/scverse/spatialdata/commit/53b9438a328c5fc2a451d2c8afab439b945ba2b8) + and the current one; we tolerate both. + + - Legacy shape: a dict ``{"piece": (parquet_file, None, None)}`` somewhere in the args + (possibly wrapped in other dicts for mixed points+images element graphs). The trailing + elements of the ``piece`` tuple encode row-group / filter constraints; we only support + unfiltered reads (hence the validation on ``check0`` / ``check1``). + - Current shape: a ``dask.dataframe.dask_expr.io.parquet.FragmentWrapper`` whose + ``.fragment.path`` is the parquet file (from ``dask_expr.io.parquet.ReadParquetPyarrowFS``). + The wrapper may live in Task ``kwargs["fragment_wrapper"]`` for simple reads, but in fused + expressions (``readparquetpyarrowfs-fused-*``) it is nested inside lists and tuples + inside a subgraph dict, so we walk every container uniformly rather than targeting named + kwargs. + + ``FragmentWrapper`` is detected via the ``.fragment.path`` attribute chain instead of an + isinstance check to avoid importing private dask_expr internals; the ``endswith(".parquet")`` + guard keeps false positives from random objects out of the result. + """ + found: list[str] = [] + + frag = getattr(obj, "fragment", None) + if frag is not None: + path = getattr(frag, "path", None) + if isinstance(path, str) and path.endswith(".parquet"): + found.append(path) + + if isinstance(obj, Mapping): + # TODO(legacy-dask): the ``"piece"`` branch targets the pre-PR-#1006 dask graph shape + # (``dask/dataframe/io/parquet/core.py`` produced ``{"piece": (file, rg, filters)}``). The + # current dask pin (``dask>=2025.12.0``) no longer emits this shape at runtime; the branch + # is kept only as a safety net for users forcing an older dask via pip. Remove once the + # lower pin is bumped past the PR-#1006 cut-off and CI covers only the new shape. + if "piece" in obj: + piece = obj["piece"] + # piece is ``(parquet_file, row_groups, filters)`` -- ``row_groups`` and ``filters`` may + # be ``None`` (single-file unfiltered) or ``[0]/[]`` (``aggregate_files=True``); both are + # whole-file reads and not interesting to us here, so we only validate the extension. + if isinstance(piece, tuple) and len(piece) >= 1 and isinstance(piece[0], str): + parquet_file = piece[0] + if not parquet_file.endswith(".parquet"): + raise ValueError( + f"Unable to parse the parquet file from the dask task {obj!r}. Please report this bug." + ) + found.append(parquet_file) + for v in obj.values(): + found.extend(_extract_parquet_paths_from_task(v)) + return found + + if isinstance(obj, (list, tuple)): + for item in obj: + found.extend(_extract_parquet_paths_from_task(item)) + return found + + # TODO(dask-task-api): the ``kwargs`` / ``args`` getattr probes here rely on the Task wrapper + # object introduced alongside PR #1006. The attribute contract is not documented as public + # (``dask.dataframe.dask_expr``), so we access it defensively via getattr and traverse every + # container uniformly. If dask stabilises a public accessor (e.g. ``task.iter_leaves()`` or an + # expr-level ``file_paths`` property) or if ``FragmentWrapper`` becomes importable from a + # stable namespace, replace the attribute-chain walk with a typed call and drop the getattrs. + kwargs = getattr(obj, "kwargs", None) + if isinstance(kwargs, Mapping): + for v in kwargs.values(): + found.extend(_extract_parquet_paths_from_task(v)) + + args = getattr(obj, "args", None) + if isinstance(args, (list, tuple)): + for a in args: + found.extend(_extract_parquet_paths_from_task(a)) + + return found + + def _search_for_backing_files_recursively(subgraph: Any, files: list[str]) -> None: # see the types allowed for the dask graph here: https://docs.dask.org/en/stable/spec.html @@ -340,66 +423,32 @@ def _search_for_backing_files_recursively(subgraph: Any, files: list[str]) -> No name = k if name is not None: if name.startswith("original-from-zarr"): - # LocalStore.store does not have an attribute path, but we keep it like this for backward compat. + # TODO(zarr-v3-store-path): the ``getattr(..., "path", None)`` fallback dates + # back to zarr v2, where ``DirectoryStore`` exposed ``.path`` and the v3 + # ``LocalStore`` exposes ``.root`` instead. With the current pin + # (``zarr>=3.0.0``) the getattr branch is never taken for local backends -- it + # only covers exotic third-party stores that still mimic the v2 attribute. + # Once we are confident no such shim stores are in use, collapse this to just + # ``v.store.root`` and drop the getattr probe. path = getattr(v.store, "path", None) if getattr(v.store, "path", None) else v.store.root files.append(str(UPath(path).resolve())) - elif name.startswith("read-parquet") or name.startswith("read_parquet"): - # Here v is a read_parquet task with arguments and the only value is a dictionary. - if "piece" in v.args[0]: - # https://github.com/dask/dask/blob/ff2488aec44d641696e0b7aa41ed9e995c710705/dask/dataframe/io/parquet/core.py#L870 - parquet_file, check0, check1 = v.args[0]["piece"] - if not parquet_file.endswith(".parquet") or check0 is not None or check1 is not None: - raise ValueError( - f"Unable to parse the parquet file from the dask subgraph {subgraph}. Please " - f"report this bug." - ) + elif "parquet" in name.lower(): + # Matches every dask task-key that wraps a parquet read across versions: + # - legacy ``read-parquet-`` / ``read_parquet-`` (pre scverse/ + # spatialdata PR #1006, https://github.com/scverse/spatialdata/pull/1006), + # - current ``read_parquet-`` plus fused-expression forms such as + # ``readparquetpyarrowfs-fused-values-`` produced by + # ``dask_expr.io.parquet.ReadParquetPyarrowFS`` when a parquet column is + # combined with other arrays (see ``test_self_contained``). + # Any false-positive key that matches but carries no parquet payload is filtered + # inside ``_extract_parquet_paths_from_task`` (paths must ``endswith(".parquet")``). + for parquet_file in _extract_parquet_paths_from_task(v): files.append(os.path.realpath(parquet_file)) - else: - # This occurs when for example points and images are mixed, the main task still starts with - # read_parquet, but the execution happens through a subgraph which we iterate over to get the - # actual read_parquet task. - # - # v.args[0] has two known shapes: - # dict – keys are task keys, values are Task objects (classic subgraph case) - # list – list of piece dicts produced when aggregate_files=True aggregates multiple - # parquet files into one partition; check0/check1 are row-group selectors - # ([0], []) rather than None, so only the file extension is validated. - args0 = v.args[0] - if isinstance(args0, dict): - for task in args0.values(): - # Recursively go through tasks, this is required because differences between dask - # versions. - piece_dict = _find_piece_dict(task) - if isinstance(piece_dict, dict) and "piece" in piece_dict: - parquet_file, check0, check1 = piece_dict["piece"] # type: ignore[misc] - if ( - not parquet_file.endswith(".parquet") - or check0 is not None - or check1 is not None - ): - raise ValueError( - f"Unable to parse the parquet file from the dask subgraph {subgraph}. " - f"Please report this bug." - ) - files.append(os.path.realpath(parquet_file)) - elif isinstance(args0, list): - for item in args0: - if isinstance(item, dict) and "piece" in item: - parquet_file = item["piece"][0] - if not parquet_file.endswith(".parquet"): - raise ValueError( - f"Unable to parse the parquet file from the dask subgraph {subgraph}. " - f"Please report this bug." - ) - files.append(os.path.realpath(parquet_file)) - else: - logger.warning( - f"Unexpected type {type(args0)} for v.args[0] in the read_parquet task graph. " - f"Backing files may not be detected correctly. Please report this as a bug." - ) - - -def _backed_elements_contained_in_path(path: Path, object: SpatialData | SpatialElement | AnnData) -> list[bool]: + + +def _backed_elements_contained_in_path( + path: Path | UPath, object: SpatialData | SpatialElement | AnnData +) -> list[bool]: """ Return the list of boolean values indicating if backing files for an object are child directory of a path. @@ -418,9 +467,16 @@ def _backed_elements_contained_in_path(path: Path, object: SpatialData | Spatial ----- If an object does not have a Dask computational graph, it will return an empty list. It is possible for a single SpatialElement to contain multiple files in their Dask computational graph. + + For a remote ``path`` (:class:`upath.UPath`), this always returns an empty list: Dask backing paths + are resolved as local filesystem paths, so they cannot be compared to object-store locations. + :meth:`spatialdata.SpatialData.write` therefore skips the local "backing files in target" guard + for remote targets; ``overwrite=True`` on a remote URL must be used only when overwriting is safe. """ + if isinstance(path, UPath): + return [] if not isinstance(path, Path): - raise TypeError(f"Expected a Path object, got {type(path)}") + raise TypeError(f"Expected a Path or UPath object, got {type(path)}") return [_is_subfolder(parent=path, child=Path(fp)) for fp in get_dask_backing_files(object)] @@ -449,16 +505,44 @@ def _is_subfolder(parent: Path, child: Path) -> bool: def _is_element_self_contained( - element: DataArray | DataTree | DaskDataFrame | GeoDataFrame | AnnData, element_path: Path + element: DataArray | DataTree | DaskDataFrame | GeoDataFrame | AnnData, + element_path: Path | UPath, ) -> bool: + """Whether element Dask graphs only reference files under ``element_path`` (local) or N/A (remote).""" + if isinstance(element_path, UPath): + # Backing-file paths are local; cannot relate them to remote keys—assume OK for this heuristic. + return True if isinstance(element, DaskDataFrame): pass # TODO when running test_save_transformations it seems that for the same element this is called multiple times return all(_backed_elements_contained_in_path(path=element_path, object=element)) +def _ensure_async_fs(fs: Any) -> Any: + """Return an async fsspec filesystem for use with zarr's FsspecStore. + + Zarr's FsspecStore expects an async filesystem. If the given fs is synchronous, + it is converted using fsspec's public API (async instance or AsyncFileSystemWrapper) + so that ZarrUserWarning is not raised. + """ + if getattr(fs, "asynchronous", False): + return fs + import fsspec + + if getattr(fs, "async_impl", False): + fs_dict = json.loads(fs.to_json()) + fs_dict["asynchronous"] = True + return fsspec.AbstractFileSystem.from_json(json.dumps(fs_dict)) + from fsspec.implementations.asyn_wrapper import AsyncFileSystemWrapper + + return AsyncFileSystemWrapper(fs, asynchronous=True) + + def _resolve_zarr_store( - path: str | Path | UPath | zarr.storage.StoreLike | zarr.Group, **kwargs: Any + path: str | Path | UPath | zarr.storage.StoreLike | zarr.Group, + *, + read_only: bool = False, + **kwargs: Any, ) -> zarr.storage.StoreLike: """ Normalize different Zarr store inputs into a usable store instance. @@ -474,9 +558,14 @@ def _resolve_zarr_store( path The input representing a Zarr store or group. Can be a filesystem path, remote path, existing store, or Zarr group. + read_only + If ``True``, constructed ``LocalStore`` / ``FsspecStore`` instances are built with + ``read_only=True``. Stores that already exist (when ``path`` is a ``StoreLike`` or + a ``zarr.Group`` whose wrapped store is not reconstructable) are returned as-is; + the caller is responsible for opening them at the right mode. **kwargs Additional keyword arguments forwarded to the underlying store - constructor (e.g. `mode`, `storage_options`). + constructor. Returns ------- @@ -486,37 +575,39 @@ def _resolve_zarr_store( ------ TypeError If the input type is unsupported. - ValueError + ValueError If a `zarr.Group` has an unsupported store type. """ - # TODO: ensure kwargs like mode are enforced everywhere and passed correctly to the store if isinstance(path, str | Path): - # if the input is str or Path, map it to UPath path = UPath(path) if isinstance(path, PosixUPath | WindowsUPath): # if the input is a local path, use LocalStore - return LocalStore(path.path) + return LocalStore(path.path, read_only=read_only) if isinstance(path, zarr.Group): - # if the input is a zarr.Group, wrap it with a store + # Re-wrap the group's store at the group's subpath. Note: zarr v3 no longer ships + # ``ConsolidatedMetadataStore`` (v2 wrapped the backend in a store; v3 surfaces + # consolidated metadata as a field on ``GroupMetadata`` instead), so we only need to + # handle the two concrete backends below. if isinstance(path.store, LocalStore): store_path = UPath(path.store.root) / path.path - return LocalStore(store_path.path) + return LocalStore(store_path.path, read_only=read_only) if isinstance(path.store, FsspecStore): - # if the store within the zarr.Group is an FSStore, return it - # but extend the path of the store with that of the zarr.Group - return FsspecStore(path.store.path + "/" + path.path, fs=path.store.fs, **kwargs) - if isinstance(path.store, zarr.storage.ConsolidatedMetadataStore): - # if the store is a ConsolidatedMetadataStore, just return the underlying FSSpec store - return path.store.store + return FsspecStore( + fs=_ensure_async_fs(path.store.fs), + path=join_fsspec_store_path(path.store.path, path.path), + read_only=read_only, + **kwargs, + ) raise ValueError(f"Unsupported store type or zarr.Group: {type(path.store)}") - if isinstance(path, zarr.storage.StoreLike): - # if the input already a store, wrap it in an FSStore - return FsspecStore(path, **kwargs) if isinstance(path, UPath): - # if input is a remote UPath, map it to an FSStore - return FsspecStore(path.path, fs=path.fs, **kwargs) + # Check before Store to avoid UnionType isinstance (zarr's ``StoreLike`` is a type alias). + return FsspecStore(_ensure_async_fs(path.fs), path=path.path, read_only=read_only, **kwargs) + if isinstance(path, zarr.abc.store.Store): + # Already a concrete store (LocalStore, FsspecStore, MemoryStore, ...). Do not pass it as ``fs=`` to + # FsspecStore -- that only accepts an async fsspec filesystem and raises on stores (e.g. ``async_impl``). + return path raise TypeError(f"Unsupported type: {type(path)}") diff --git a/src/spatialdata/_io/io_points.py b/src/spatialdata/_io/io_points.py index 03ef3338..179a3e21 100644 --- a/src/spatialdata/_io/io_points.py +++ b/src/spatialdata/_io/io_points.py @@ -1,7 +1,5 @@ from __future__ import annotations -from pathlib import Path - import zarr from dask.dataframe import DataFrame as DaskDataFrame from dask.dataframe import read_parquet @@ -13,6 +11,7 @@ overwrite_coordinate_transformations_non_raster, ) from spatialdata._io.format import CurrentPointsFormat, PointsFormats, _parse_version +from spatialdata._store import parquet_fs_and_path from spatialdata.models import get_axes_names from spatialdata.transformations._utils import ( _get_transformations, @@ -20,27 +19,27 @@ ) -def _read_points( - store: str | Path, -) -> DaskDataFrame: - """Read points from a zarr store.""" - f = zarr.open(Path(store), mode="r") # Path avoids zarr v3 URL-parsing special chars (e.g. #) in names - - version = _parse_version(f, expect_attrs_key=True) +def _read_points(group: zarr.Group) -> DaskDataFrame: + """Read a points element from an open zarr group.""" + version = _parse_version(group, expect_attrs_key=True) assert version is not None points_format = PointsFormats[version] - store_root = f.store_path.store.root - path = store_root / f.path / "points.parquet" - # cache on remote file needed for parquet reader to work - # TODO: allow reading in the metadata without caching all the data - points = read_parquet("simplecache::" + str(path) if str(path).startswith("http") else path) + fs, parquet_path = parquet_fs_and_path(group, "points.parquet") + # Use the fsspec filesystem (see parquet_fs_and_path): dask's pyarrow-FS reader would + # otherwise return known-empty categoricals and wrong per-partition lengths. The fsspec + # path returns categories as unknown (so write_points' as_known() recomputes them) and + # preserves partition boundaries (so transform() on multi-partition points works). + # TODO: allow reading in the metadata without materializing the data. + points = read_parquet(parquet_path, filesystem=fs) assert isinstance(points, DaskDataFrame) + if points.index.name == "__null_dask_index__": + points = points.rename_axis(None) - transformations = _get_transformations_from_ngff_dict(f.attrs.asdict()["coordinateTransformations"]) + transformations = _get_transformations_from_ngff_dict(group.attrs.asdict()["coordinateTransformations"]) _set_transformations(points, transformations) - attrs = points_format.attrs_from_dict(f.attrs.asdict()) + attrs = points_format.attrs_from_dict(group.attrs.asdict()) if len(attrs): points.attrs["spatialdata_attrs"] = attrs return points @@ -69,8 +68,7 @@ def write_points( transformations = _get_transformations(points) assert transformations is not None # mypy: validate_element() in _write_element guarantees this - store_root = group.store_path.store.root - path = store_root / group.path / "points.parquet" + fs, parquet_path = parquet_fs_and_path(group, "points.parquet") # The following code iterates through all columns in the 'points' DataFrame. If the column's datatype is # 'category', it checks whether the categories of this column are known. If not, it explicitly converts the @@ -85,7 +83,7 @@ def write_points( points_without_transform = points.copy() del points_without_transform.attrs["transform"] - points_without_transform.to_parquet(path) + points_without_transform.to_parquet(parquet_path, filesystem=fs) attrs = element_format.attrs_to_dict(points.attrs) attrs["version"] = element_format.spatialdata_format_version diff --git a/src/spatialdata/_io/io_raster.py b/src/spatialdata/_io/io_raster.py index 2feb7a77..fd2c3c3c 100644 --- a/src/spatialdata/_io/io_raster.py +++ b/src/spatialdata/_io/io_raster.py @@ -1,7 +1,6 @@ from __future__ import annotations from collections.abc import Sequence -from pathlib import Path from typing import Any, Literal, TypeGuard, cast import dask.array as da @@ -27,6 +26,7 @@ RasterFormatType, get_ome_zarr_format, ) +from spatialdata._store import store_from_group from spatialdata._utils import get_pyramid_levels from spatialdata.models._utils import get_channel_names from spatialdata.models.models import ATTRS_KEY @@ -160,13 +160,15 @@ def _prepare_storage_options( def _read_multiscale( - store: str | Path, raster_type: Literal["image", "labels"], reader_format: Format + group: zarr.Group, raster_type: Literal["image", "labels"], reader_format: Format ) -> DataArray | DataTree: - assert isinstance(store, str | Path) assert raster_type in ["image", "labels"] + # ome_zarr.io.ZarrLocation needs a store rooted at this group's location, not at the + # SpatialData container root, so we re-root the parent store at ``group.path``. + resolved_store = store_from_group(group, read_only=True) nodes: list[Node] = [] - image_loc = ZarrLocation(store, fmt=reader_format) + image_loc = ZarrLocation(resolved_store, fmt=reader_format) if exists := image_loc.exists(): image_reader = Reader(image_loc)() image_nodes = list(image_reader) diff --git a/src/spatialdata/_io/io_shapes.py b/src/spatialdata/_io/io_shapes.py index 3b6e18e3..e8a17f52 100644 --- a/src/spatialdata/_io/io_shapes.py +++ b/src/spatialdata/_io/io_shapes.py @@ -1,6 +1,5 @@ from __future__ import annotations -from pathlib import Path from typing import Any, Literal import numpy as np @@ -23,6 +22,7 @@ ShapesFormatV03, _parse_version, ) +from spatialdata._store import parquet_fs_and_path from spatialdata.models import ShapesModel, get_axes_names from spatialdata.transformations._utils import ( _get_transformations, @@ -30,39 +30,36 @@ ) -def _read_shapes( - store: str | Path, -) -> GeoDataFrame: - """Read shapes from a zarr store.""" - f = zarr.open(Path(store), mode="r") # Path avoids zarr v3 URL-parsing special chars (e.g. #) in names - version = _parse_version(f, expect_attrs_key=True) +def _read_shapes(group: zarr.Group) -> GeoDataFrame: + """Read a shapes element from an open zarr group.""" + version = _parse_version(group, expect_attrs_key=True) assert version is not None shape_format = ShapesFormats[version] if isinstance(shape_format, ShapesFormatV01): - coords = np.array(f["coords"]) - index = np.array(f["Index"]) - typ = shape_format.attrs_from_dict(f.attrs.asdict()) + coords = np.array(group["coords"]) + index = np.array(group["Index"]) + typ = shape_format.attrs_from_dict(group.attrs.asdict()) if typ.name == "POINT": - radius = np.array(f["radius"]) + radius = np.array(group["radius"]) geometry = from_ragged_array(typ, coords) geo_df = GeoDataFrame({"geometry": geometry, "radius": radius}, index=index) else: - offsets_keys = [k for k in f if k.startswith("offset")] + offsets_keys = [k for k in group if k.startswith("offset")] offsets_keys = natsorted(offsets_keys) - offsets = tuple(np.array(f[k]).flatten() for k in offsets_keys) + offsets = tuple(np.array(group[k]).flatten() for k in offsets_keys) geometry = from_ragged_array(typ, coords, offsets) geo_df = GeoDataFrame({"geometry": geometry}, index=index) elif isinstance(shape_format, ShapesFormatV02 | ShapesFormatV03): - store_root = f.store_path.store.root - path = Path(store_root) / f.path / "shapes.parquet" - geo_df = read_parquet(path) + fs, parquet_path = parquet_fs_and_path(group, "shapes.parquet") + with fs.open(parquet_path, "rb") as src: + geo_df = read_parquet(src) else: raise ValueError( f"Unsupported shapes format {shape_format} from version {version}. Please update the spatialdata library." ) - transformations = _get_transformations_from_ngff_dict(f.attrs.asdict()["coordinateTransformations"]) + transformations = _get_transformations_from_ngff_dict(group.attrs.asdict()["coordinateTransformations"]) _set_transformations(geo_df, transformations) return geo_df @@ -168,13 +165,13 @@ def _write_shapes_v02_v03( """ from spatialdata.models._utils import TRANSFORM_KEY - store_root = group.store_path.store.root - path = store_root / group.path / "shapes.parquet" + fs, parquet_path = parquet_fs_and_path(group, "shapes.parquet") # Temporarily remove transformations from attrs to avoid serialization issues transforms = shapes.attrs[TRANSFORM_KEY] del shapes.attrs[TRANSFORM_KEY] - shapes.to_parquet(path, geometry_encoding=geometry_encoding) + with fs.open(parquet_path, "wb") as sink: + shapes.to_parquet(sink, geometry_encoding=geometry_encoding) shapes.attrs[TRANSFORM_KEY] = transforms attrs = element_format.attrs_to_dict(shapes.attrs) diff --git a/src/spatialdata/_io/io_table.py b/src/spatialdata/_io/io_table.py index 3eb4b092..4474da12 100644 --- a/src/spatialdata/_io/io_table.py +++ b/src/spatialdata/_io/io_table.py @@ -1,7 +1,5 @@ from __future__ import annotations -from pathlib import Path - import numpy as np import zarr from anndata import AnnData @@ -19,16 +17,16 @@ from spatialdata.models import TableModel, get_table_keys -def _read_table(store: str | Path) -> AnnData: - table = read_anndata_zarr(str(store)) +def _read_table(group: zarr.Group) -> AnnData: + """Read a table element from an open zarr group.""" + # anndata's read_zarr accepts a StoreLike; pass the group's store sub-rooted at group.path. + # The simplest portable way: pass the group itself, which anndata supports. + table = read_anndata_zarr(group) - f = zarr.open(Path(store), mode="r") # Path avoids zarr v3 URL-parsing special chars (e.g. #) in names - version = _parse_version(f, expect_attrs_key=False) + version = _parse_version(group, expect_attrs_key=False) assert version is not None table_format = TablesFormats[version] - f.store.close() - if isinstance(table_format, TablesFormatV01 | TablesFormatV02): if TableModel.ATTRS_KEY in table.uns: # fill out eventual missing attributes that has been omitted because their value was None diff --git a/src/spatialdata/_io/io_zarr.py b/src/spatialdata/_io/io_zarr.py index 4c410fab..c10c5e47 100644 --- a/src/spatialdata/_io/io_zarr.py +++ b/src/spatialdata/_io/io_zarr.py @@ -1,13 +1,12 @@ from __future__ import annotations -import os import warnings from collections.abc import Callable from json import JSONDecodeError from pathlib import Path from typing import Any, Literal, cast -import zarr.storage +import zarr from anndata import AnnData from dask.dataframe import DataFrame as DaskDataFrame from geopandas import GeoDataFrame @@ -27,12 +26,16 @@ from spatialdata._io.io_shapes import _read_shapes from spatialdata._io.io_table import _read_table from spatialdata._logging import logger +from spatialdata._store import ( + open_zarr_for_read, + path_from_store, + store_from_group, +) from spatialdata._types import Raster_T def _read_zarr_group_spatialdata_element( root_group: zarr.Group, - root_store_path: str, sdata_version: Literal["0.1", "0.2"], selector: set[str], read_func: Callable[..., Any], @@ -54,7 +57,6 @@ def _read_zarr_group_spatialdata_element( # skip hidden files like .zgroup or .zmetadata continue elem_group = group[subgroup_name] - elem_group_path = os.path.join(root_store_path, elem_group.path) with handle_read_errors( on_bad_files, location=f"{group.path}/{subgroup_name}", @@ -68,14 +70,24 @@ def _read_zarr_group_spatialdata_element( ), ): if element_type in ["image", "labels"]: + # Raster readers go through ome_zarr's ZarrLocation which independently + # re-resolves the element's metadata, so corruption of the element's own + # zarr.json surfaces there as a clean OSError. Pass the cached group here. reader_format = get_raster_format_for_read(elem_group, sdata_version) element = read_func( - elem_group_path, + elem_group, cast(Literal["image", "labels"], element_type), reader_format, ) elif element_type in ["shapes", "points", "tables"]: - element = read_func(elem_group_path) + # Non-raster readers consume ``group.attrs`` directly; the parent's + # consolidated-metadata cache would otherwise mask a corrupted or + # missing element-level ``zarr.json`` / ``.zattrs``. Re-open from the + # store so the corruption surfaces as OSError / JSONDecodeError. + elem_group_fresh = open_zarr_for_read( + store_from_group(elem_group, read_only=True), as_group=True + ) + element = read_func(elem_group_fresh) else: raise ValueError(f"Unknown element type {element_type}") element_container[subgroup_name] = element @@ -123,17 +135,24 @@ def get_raster_format_for_read( def read_zarr( - store: str | Path | UPath | zarr.Group, + store: str | Path | zarr.Group | zarr.storage.StoreLike, selection: None | tuple[str] = None, on_bad_files: Literal[BadFileHandleMethod.ERROR, BadFileHandleMethod.WARN] = BadFileHandleMethod.ERROR, ) -> SpatialData: """ - Read a SpatialData dataset from a zarr store (on-disk or remote). + Read a SpatialData dataset from a zarr store (local or remote). Parameters ---------- store - Path, URL, or zarr.Group to the zarr store (on-disk or remote). + One of: + + - A local filesystem path (``str`` or :class:`pathlib.Path`) + - A zarr store (e.g. :class:`zarr.storage.LocalStore`, + :class:`zarr.storage.FsspecStore`, :class:`zarr.storage.MemoryStore`) + carrying its own filesystem (and credentials) — the supported form for + remote backends like S3 / Azure / GCS + - An already-open :class:`zarr.Group` selection List of elements to read from the zarr store (images, labels, points, shapes, table). If None, all elements are @@ -153,24 +172,21 @@ def read_zarr( ------- A SpatialData object. """ - from spatialdata._io._utils import _resolve_zarr_store - - resolved_store = _resolve_zarr_store(store) - root_group = zarr.open_group(resolved_store, mode="r") - # the following is the SpatialDataContainerFormat version - if "spatialdata_attrs" not in root_group.metadata.attributes: - # backward compatibility for pre-versioned SpatialData zarr stores - sdata_version: Literal["0.1", "0.2"] = "0.1" + # Coerce all input forms to (resolved_store, backing_path) where backing_path is the + # user-facing path (Path | UPath | None) for `sdata.path` and resolved_store is the + # actual zarr backend store we open the root group from. + if isinstance(store, zarr.Group): + # Already open: re-root the underlying store at this group's path. + from spatialdata._store import store_from_group + + resolved_store: Any = store_from_group(store, read_only=True) + elif isinstance(store, zarr.abc.store.Store): + resolved_store = store else: - sdata_version = root_group.metadata.attributes["spatialdata_attrs"]["version"] - if sdata_version == "0.1": - warnings.warn( - "SpatialData is not stored in the most current format. If you want to use Zarr v3" - ", please write the store to a new location using `sdata.write()`.", - UserWarning, - stacklevel=2, - ) - root_store_path = root_group.store.root + from spatialdata._io._utils import _resolve_zarr_store + + resolved_store = _resolve_zarr_store(store, read_only=True) + backing_path = path_from_store(resolved_store) images: dict[str, Raster_T] = {} labels: dict[str, Raster_T] = {} @@ -178,50 +194,73 @@ def read_zarr( shapes: dict[str, GeoDataFrame] = {} tables: dict[str, AnnData] = {} - selector = {"images", "labels", "points", "shapes", "tables"} if not selection else set(selection or []) - logger.debug(f"Reading selection {selector}") - - # we could make this more readable. One can get lost when looking at this dict and iteration over the items - group_readers: dict[ - Literal["images", "labels", "shapes", "points", "tables"], - tuple[ - Callable[..., Any], - Literal["image", "labels", "shapes", "points", "tables"], - dict[str, Raster_T] | dict[str, DaskDataFrame] | dict[str, GeoDataFrame] | dict[str, AnnData], - ], - ] = { - # ome-zarr-py needs a kwargs that has "image" has key. So here we have "image" and not "images" - "images": (_read_multiscale, "image", images), - "labels": (_read_multiscale, "labels", labels), - "points": (_read_points, "points", points), - "shapes": (_read_shapes, "shapes", shapes), - "tables": (_read_table, "tables", tables), - } - for group_name, ( - read_func, - element_type, - element_container, - ) in group_readers.items(): - _read_zarr_group_spatialdata_element( - root_group=root_group, - root_store_path=root_store_path, - sdata_version=sdata_version, - selector=selector, - read_func=read_func, - group_name=group_name, - element_type=element_type, - element_container=element_container, - on_bad_files=on_bad_files, - ) - - # read attrs metadata - attrs = root_group.attrs.asdict() - if "spatialdata_attrs" in attrs: - # when refactoring the read_zarr function into reading componenets separately (and according to the version), - # we can move the code below (.pop()) into attrs_from_dict() - attrs.pop("spatialdata_attrs") - else: - attrs = None + try: + # Use the consolidated + zarr-v3-pinned fast path. See ``open_zarr_for_read`` for why + # pinning ``zarr_format=3`` matters over remote backends (avoids five small v2-metadata + # probes per open) and how the fallback keeps legacy / non-consolidated stores working. + root_group = open_zarr_for_read(resolved_store, as_group=True) + # the following is the SpatialDataContainerFormat version + if "spatialdata_attrs" not in root_group.metadata.attributes: + # backward compatibility for pre-versioned SpatialData zarr stores + sdata_version: Literal["0.1", "0.2"] = "0.1" + else: + sdata_version = root_group.metadata.attributes["spatialdata_attrs"]["version"] + if sdata_version == "0.1": + warnings.warn( + "SpatialData is not stored in the most current format. If you want to use Zarr v3" + ", please write the store to a new location using `sdata.write()`.", + UserWarning, + stacklevel=2, + ) + + selector = {"images", "labels", "points", "shapes", "tables"} if not selection else set(selection or []) + logger.debug(f"Reading selection {selector}") + + # we could make this more readable. One can get lost when looking at this dict and iteration over the items + group_readers: dict[ + Literal["images", "labels", "shapes", "points", "tables"], + tuple[ + Callable[..., Any], + Literal["image", "labels", "shapes", "points", "tables"], + dict[str, Raster_T] | dict[str, DaskDataFrame] | dict[str, GeoDataFrame] | dict[str, AnnData], + ], + ] = { + # ome-zarr-py needs a kwargs that has "image" has key. So here we have "image" and not "images" + "images": (_read_multiscale, "image", images), + "labels": (_read_multiscale, "labels", labels), + "points": (_read_points, "points", points), + "shapes": (_read_shapes, "shapes", shapes), + "tables": (_read_table, "tables", tables), + } + for group_name, ( + read_func, + element_type, + element_container, + ) in group_readers.items(): + _read_zarr_group_spatialdata_element( + root_group=root_group, + sdata_version=sdata_version, + selector=selector, + read_func=read_func, + group_name=group_name, + element_type=element_type, + element_container=element_container, + on_bad_files=on_bad_files, + ) + + # read attrs metadata + attrs = root_group.attrs.asdict() + if "spatialdata_attrs" in attrs: + # when refactoring the read_zarr function into reading componenets separately (and according to the version) + # we can move the code below (.pop()) into attrs_from_dict() + attrs.pop("spatialdata_attrs") + else: + attrs = None + finally: + # Only close stores we constructed ourselves; if the caller handed us a store or Group, + # they retain ownership. + if not isinstance(store, (zarr.Group, zarr.abc.store.Store)): + resolved_store.close() sdata = SpatialData( images=images, @@ -231,12 +270,12 @@ def read_zarr( tables=tables, attrs=attrs, ) - sdata.path = resolved_store.root + sdata._path = backing_path return sdata def _get_groups_for_element( - zarr_path: Path, element_type: str, element_name: str, use_consolidated: bool = True + zarr_path: Path | UPath, element_type: str, element_name: str, use_consolidated: bool = True ) -> tuple[zarr.Group, zarr.Group, zarr.Group]: """ Get the Zarr groups for the root, element_type and element for a specific element. @@ -265,8 +304,8 @@ def _get_groups_for_element( ------- The Zarr groups for the root, element_type and element for a specific element. """ - if not isinstance(zarr_path, Path): - raise ValueError("zarr_path should be a Path object") + if not isinstance(zarr_path, (Path, UPath)): + raise ValueError("zarr_path should be a Path or UPath object") if element_type not in [ "images", @@ -289,7 +328,7 @@ def _get_groups_for_element( return root_group, element_type_group, element_name_group -def _group_for_element_exists(zarr_path: Path, element_type: str, element_name: str) -> bool: +def _group_for_element_exists(zarr_path: Path | UPath, element_type: str, element_name: str) -> bool: """ Check if the group for an element exists. @@ -319,14 +358,35 @@ def _group_for_element_exists(zarr_path: Path, element_type: str, element_name: return exists -def _write_consolidated_metadata(path: Path | str | None) -> None: +def _write_consolidated_metadata(path: Path | UPath | str | None) -> None: if path is not None: - f = zarr.open_group(path, mode="r+", use_consolidated=False) + if isinstance(path, UPath): + store = _resolve_zarr_store(path) + f = zarr.open_group(store, mode="r+", use_consolidated=False) + else: + f = zarr.open_group(path, mode="r+", use_consolidated=False) # .parquet files are not recognized as proper zarr and thus throw a warning. This does not affect SpatialData. # and therefore we silence it for our users as they can't do anything about this. # TODO check with remote PR whether we can prevent this warning at least for points data and whether with zarrv3 # that pr would still work. with warnings.catch_warnings(): warnings.filterwarnings("ignore", category=zarr.errors.ZarrUserWarning) + # Consolidate at the root, then at every element group + # (``/``). The per-element consolidation is what lets our readers + # -- which re-open each element via a child-rooted ``FsspecStore`` -- actually + # consume consolidated metadata at element open time. A root-only consolidation + # only benefits the first ``zarr.open_group`` call in ``read_zarr``; every + # subsequent ``zarr.open(elem_store, ...)`` rooted at the element path would + # still walk its own subtree one ``zarr.json`` at a time because the + # consolidated-metadata field lives on the *root* ``zarr.json``, not the + # child's. Consolidating per-element writes the field on every element's own + # ``zarr.json`` so a child-rooted open is a single GET regardless of depth. zarr.consolidate_metadata(f.store) + for group_name in ("images", "labels", "points", "shapes", "tables"): + if group_name not in f: + continue + for element_name in f[group_name]: + if element_name.startswith("."): + continue + zarr.consolidate_metadata(f.store, path=f"{group_name}/{element_name}") f.store.close() diff --git a/src/spatialdata/_store.py b/src/spatialdata/_store.py new file mode 100644 index 00000000..378dda9d --- /dev/null +++ b/src/spatialdata/_store.py @@ -0,0 +1,168 @@ +from __future__ import annotations + +from contextlib import contextmanager +from pathlib import Path +from typing import Any, TypeAlias + +import zarr +from upath import UPath +from zarr.storage import FsspecStore, LocalStore + +PathLike: TypeAlias = Path | UPath + + +def normalize_path(path: str | PathLike, storage_options: dict[str, Any] | None = None) -> PathLike: + """Normalize a path-like input to ``Path`` (local) or ``UPath`` (remote).""" + if isinstance(path, str): + return UPath(path, **(storage_options or {})) if "://" in path else Path(path) + if isinstance(path, (Path, UPath)): + return path + raise TypeError(f"path must be a `str`, `Path`, or `UPath` object, got {type(path).__name__}.") + + +def path_from_store(store: Any) -> PathLike | None: + """Derive the user-facing path from a zarr store, or ``None`` for stores without one. + + - ``LocalStore`` → ``Path(store.root)`` + - ``FsspecStore`` → ``UPath("{protocol}://{store.path}", fs=sync_fs)`` + - Anything else (``MemoryStore``, custom stores) → ``None`` (no meaningful path) + + The fsspec branch unwraps ``AsyncFileSystemWrapper`` so the returned ``UPath`` + carries the original sync filesystem (required by pyarrow's ``FSSpecHandler``). + """ + if isinstance(store, LocalStore): + return Path(store.root) + if isinstance(store, FsspecStore): + protocol = getattr(store.fs, "protocol", None) + if isinstance(protocol, (list, tuple)): + protocol = protocol[0] if protocol else "file" + elif protocol is None: + protocol = "file" + fs = store.fs + while True: + inner = getattr(fs, "sync_fs", None) + if inner is None or inner is fs: + break + fs = inner + return UPath(f"{protocol}://{store.path}", fs=fs) + return None + + +def store_from_group(group: zarr.Group, *, read_only: bool = True) -> Any: + """Return a zarr store re-rooted at ``group.path``. + + For consumers (e.g. ``ome_zarr.io.ZarrLocation``, ``anndata.read_zarr``) that + need a store and don't understand sub-paths inside one. Falls back to returning + the parent store unchanged for stores we cannot re-root (``MemoryStore`` etc.). + """ + from spatialdata._io._utils import join_fsspec_store_path + + parent = group.store + if isinstance(parent, LocalStore): + return LocalStore(Path(parent.root) / group.path, read_only=read_only) + if isinstance(parent, FsspecStore): + return FsspecStore( + parent.fs, + path=join_fsspec_store_path(parent.path, group.path), + read_only=read_only, + ) + return parent + + +def parquet_fs_and_path(group: zarr.Group, *child_parts: str) -> tuple[Any, str]: + """Derive a (sync) fsspec filesystem + path for parquet I/O at ``group/``. + + Used by the parquet readers/writers (points, shapes) so they can locate their backing + file without going through a ``UPath``: the zarr group already carries its store (and + thus the filesystem), so we derive both directly. + + We deliberately return the *fsspec* filesystem (not a pyarrow one): dask's + ``ReadParquetPyarrowFS`` path eagerly materializes pyarrow dictionary columns into + ``known=True`` empty categoricals and reports wrong per-partition lengths, both of + which break downstream code (e.g. ``transform()`` on multi-partition points). Routing + dask/geopandas through the fsspec filesystem uses their default parquet engine, which + preserves categories-unknown and correct partition boundaries. + + The ``sync_fs`` unwrap recovers the original synchronous filesystem from zarr's + ``AsyncFileSystemWrapper`` (dask/geopandas parquet I/O is synchronous). + """ + from spatialdata._io._utils import join_fsspec_store_path + + store = group.store + child = "/".join(child_parts) + + if isinstance(store, LocalStore): + import fsspec + + local_path = Path(store.root) / group.path / child if child else Path(store.root) / group.path + return fsspec.filesystem("file"), str(local_path) + + if isinstance(store, FsspecStore): + fs = store.fs + while True: + inner = getattr(fs, "sync_fs", None) + if inner is None or inner is fs: + break + fs = inner + sub = f"{group.path}/{child}" if child else group.path + return fs, join_fsspec_store_path(store.path, sub) + + raise ValueError(f"Cannot derive a filesystem for store of type {type(store).__name__}") + + +@contextmanager +def open_read_store(path: PathLike) -> Any: + """Open *path* as a read-only zarr backend store. + + The store is constructed with ``read_only=True`` so the underlying + ``LocalStore`` / ``FsspecStore`` refuses writes at the store layer (not just + at the group ``mode="r"`` level). This also lets public HTTPS zarrs skip any + write-capability probe that fsspec may otherwise perform. + """ + from spatialdata._io._utils import _resolve_zarr_store + + resolved_store = _resolve_zarr_store(path, read_only=True) + try: + yield resolved_store + finally: + resolved_store.close() + + +def open_zarr_for_read(store: Any, *, as_group: bool = True) -> Any: + """Open a zarr group or node for reading with remote-friendly defaults. + + Prefers the fast path: pinned ``zarr_format=3`` (we only ever write v3 stores, + so skipping v2-metadata auto-probes saves up to five small GETs per open on + remote backends) and ``use_consolidated=True`` (requires the root / element + ``zarr.json`` to carry the ``consolidated_metadata`` field produced by + ``_write_consolidated_metadata``). Falls back to ``zarr.open*`` with no + format/consolidation hints for legacy or third-party stores that predate + either convention. + + Parameters + ---------- + store + A ``zarr.storage.StoreLike`` -- typically the value yielded by + ``open_read_store``. + as_group + If ``True`` (default) use ``zarr.open_group``; if ``False`` use + ``zarr.open`` which returns either a ``Group`` or an ``Array`` based on + the metadata at the store root. + """ + fn = zarr.open_group if as_group else zarr.open + try: + return fn(store, mode="r", zarr_format=3, use_consolidated=True) + except (ValueError, FileNotFoundError): + return fn(store, mode="r") + + +@contextmanager +def open_write_store(path: PathLike) -> Any: + """Open *path* as a writable zarr backend store (``read_only=False``).""" + from spatialdata._io._utils import _resolve_zarr_store + + resolved_store = _resolve_zarr_store(path, read_only=False) + try: + yield resolved_store + finally: + resolved_store.close() diff --git a/tests/io/test_multi_table.py b/tests/io/test_multi_table.py index 19c70897..5350fcb0 100644 --- a/tests/io/test_multi_table.py +++ b/tests/io/test_multi_table.py @@ -113,6 +113,8 @@ def test_set_table_nonexisting_target(self, full_sdata): def test_set_table_annotates_spatialelement(self, full_sdata, tmp_path): tmpdir = Path(tmp_path) / "tmp.zarr" del full_sdata["table"].uns[TableModel.ATTRS_KEY] + # full_sdata table has region labels2d+poly; set to labels2d only so set_table_annotates_spatialelement succeeds + full_sdata["table"].obs["region"] = pd.Categorical(["labels2d"] * full_sdata["table"].n_obs) with pytest.raises( TypeError, match="No current annotation metadata found. Please specify both region_key and instance_key." ): diff --git a/tests/io/test_readwrite.py b/tests/io/test_readwrite.py index 034c01d3..fed6e74b 100644 --- a/tests/io/test_readwrite.py +++ b/tests/io/test_readwrite.py @@ -1289,6 +1289,17 @@ def test_read_sdata(tmp_path: Path, points: SpatialData) -> None: assert_spatial_data_objects_are_identical(sdata_from_path, sdata_from_zarr_group) +def test_path_setter_coerces_str_to_path_or_upath(tmp_path: Path) -> None: + """``SpatialData.path`` is stored as Path | UPath | None; strings are normalized like ``write()``.""" + sdata = SpatialData() + p = tmp_path / "store.zarr" + sdata.path = str(p) + assert isinstance(sdata.path, Path) + assert sdata.path == p + sdata.path = "s3://bucket/key.zarr" + assert isinstance(sdata.path, UPath) + + def test_sdata_with_nan_in_obs(tmp_path: Path) -> None: """Test writing SpatialData with mixed string/NaN values in obs works correctly. diff --git a/tests/io/test_store.py b/tests/io/test_store.py new file mode 100644 index 00000000..e3b910eb --- /dev/null +++ b/tests/io/test_store.py @@ -0,0 +1,174 @@ +from __future__ import annotations + +import tempfile +from pathlib import Path + +import pytest +import zarr +from upath import UPath +from zarr.storage import FsspecStore, LocalStore, MemoryStore + +from spatialdata import SpatialData +from spatialdata._io._utils import _resolve_zarr_store +from spatialdata._store import ( + normalize_path, + open_read_store, + open_write_store, + parquet_fs_and_path, + path_from_store, + store_from_group, +) +from spatialdata.testing import assert_spatial_data_objects_are_identical + + +def test_normalize_path_local_string(tmp_path: Path) -> None: + result = normalize_path(str(tmp_path / "store.zarr")) + assert isinstance(result, Path) + + +def test_normalize_path_remote_string() -> None: + result = normalize_path("s3://bucket/store.zarr") + assert isinstance(result, UPath) + + +def test_normalize_path_storage_options() -> None: + result = normalize_path("s3://bucket/store.zarr", storage_options={"anon": True}) + assert isinstance(result, UPath) + assert getattr(result.fs, "anon", None) is True + + +def test_normalize_path_passthrough_path(tmp_path: Path) -> None: + p = tmp_path / "store.zarr" + assert normalize_path(p) is p + + +def test_normalize_path_passthrough_upath() -> None: + u = UPath("s3://bucket/store.zarr") + assert normalize_path(u) is u + + +def test_open_read_and_write_store_roundtrip(tmp_path: Path) -> None: + path = tmp_path / "store.zarr" + + with open_write_store(path) as store: + group = zarr.create_group(store=store, overwrite=True) + group.attrs["answer"] = 42 + + with open_read_store(path) as store: + group = zarr.open_group(store=store, mode="r") + assert group.attrs["answer"] == 42 + + +def test_path_from_store_local(tmp_path: Path) -> None: + """Path derivation from a LocalStore returns the on-disk root as a Path.""" + path = tmp_path / "store.zarr" + + with open_write_store(path) as store: + zarr.create_group(store=store, overwrite=True) + assert path_from_store(store) == path + + +def test_store_from_group_local_subroots(tmp_path: Path) -> None: + """`store_from_group` returns a LocalStore re-rooted at the group's path.""" + path = tmp_path / "store.zarr" + + with open_write_store(path) as store: + root = zarr.create_group(store=store, overwrite=True) + group = root.require_group("images").require_group("image") + + sub = store_from_group(group, read_only=True) + assert isinstance(sub, LocalStore) + assert Path(sub.root) == path / "images" / "image" + assert sub.read_only is True + + +def test_parquet_fs_and_path_local(tmp_path: Path) -> None: + """`parquet_fs_and_path` returns an fsspec LocalFileSystem and a joined local path string.""" + from fsspec.implementations.local import LocalFileSystem + + path = tmp_path / "store.zarr" + with open_write_store(path) as store: + root = zarr.create_group(store=store, overwrite=True) + group = root.require_group("points").require_group("p1") + + fs, parquet_path = parquet_fs_and_path(group, "points.parquet") + assert isinstance(fs, LocalFileSystem) + assert parquet_path == str(path / "points" / "p1" / "points.parquet") + + +def test_resolve_zarr_store_returns_existing_zarr_stores_unchanged() -> None: + """StoreLike inputs must not be wrapped as FsspecStore(fs=store) -- that is only for async filesystems.""" + mem = MemoryStore() + assert _resolve_zarr_store(mem) is mem + loc = LocalStore(tempfile.mkdtemp()) + assert _resolve_zarr_store(loc) is loc + + +def test_resolve_zarr_store_forwards_read_only_local(tmp_path: Path) -> None: + """`_resolve_zarr_store(..., read_only=True)` must reach the LocalStore constructor.""" + store = _resolve_zarr_store(tmp_path / "store.zarr", read_only=True) + assert isinstance(store, LocalStore) + assert store.read_only is True + + +def test_resolve_zarr_store_forwards_read_only_remote() -> None: + """`_resolve_zarr_store(..., read_only=True)` must reach the FsspecStore constructor.""" + from fsspec.implementations.memory import MemoryFileSystem + + upath = UPath("memory://ro-remote.zarr", fs=MemoryFileSystem(skip_instance_cache=True)) + store = _resolve_zarr_store(upath, read_only=True) + assert isinstance(store, FsspecStore) + assert store.read_only is True + + +def test_path_from_store_remote() -> None: + """`path_from_store` on a remote FsspecStore yields a UPath with the original sync fs.""" + import fsspec + from fsspec.implementations.asyn_wrapper import AsyncFileSystemWrapper + + fs = fsspec.filesystem("memory") + async_fs = AsyncFileSystemWrapper(fs, asynchronous=True) + store = FsspecStore(async_fs, path="/some/remote.zarr") + + result = path_from_store(store) + assert isinstance(result, UPath) + assert getattr(result.fs, "protocol", None) == "memory" + + +def test_path_from_store_memory_returns_none() -> None: + """Stores without a meaningful filesystem path (MemoryStore, custom) return None.""" + assert path_from_store(MemoryStore()) is None + + +# --------------------------------------------------------------------------- +# Public-API: passing a zarr StoreLike directly to write() / read_zarr(). +# This is the headline capability of the refactor -- users hand us a configured +# zarr store (e.g. FsspecStore with embedded credentials) instead of a UPath. +# --------------------------------------------------------------------------- + + +def test_write_and_read_via_local_store(points: SpatialData, tmp_path: Path) -> None: + """`write(store)` and `read_zarr(store)` round-trip through a zarr LocalStore. + + Exercises the StoreLike branches in ``write()`` (path_from_store) and + ``read_zarr()`` (use the store directly), and that ``sdata.path`` is derived + from the store. + """ + store_path = tmp_path / "store.zarr" + + write_store = LocalStore(str(store_path)) + points.write(write_store, overwrite=True) + # sdata.path is derived from the store, as a plain Path + assert points.path == store_path + + read_store = LocalStore(str(store_path), read_only=True) + read = SpatialData.read(read_store) + assert read.path == store_path + assert_spatial_data_objects_are_identical(points, read) + + +def test_write_to_memory_store_raises() -> None: + """A store with no filesystem path (MemoryStore) is rejected with a clear error.""" + sdata = SpatialData() + with pytest.raises(NotImplementedError, match="does not expose a filesystem path"): + sdata.write(MemoryStore()) diff --git a/tests/io/test_store_abstractions.py b/tests/io/test_store_abstractions.py new file mode 100644 index 00000000..c93b7191 --- /dev/null +++ b/tests/io/test_store_abstractions.py @@ -0,0 +1,296 @@ +"""Abstraction stress tests for ``SpatialData`` io against a memory-backed ``UPath``. + +These tests exercise the same read/write code paths that would be hit by a real remote +backend (S3/Azure/GCS/HTTPS), using only ``fsspec.filesystem("memory")`` and a thin +no-listing wrapper to approximate HTTP-like semantics. No emulators, no network. + +The file is deliberately scoped to the **public interface** (``SpatialData.read`` / +``SpatialData.write``) plus tamper-evident inspection of the underlying fsspec backend; +the lower-level ``ZarrStore`` / ``_resolve_zarr_store`` plumbing is unit-tested separately +in ``tests/io/test_store.py``. + +Coverage goals (generic, not provider-specific): +- ``SpatialData.read`` does not mutate backend bytes (tamper-evident snapshot equality). +- Full write / write-read-write round-trip through a remote-backed ``UPath`` for images, + labels, shapes, points, and a full sdata. The write-read-write cycle specifically pins + the categorical-schema invariant that the arrow-filesystem migration (this PR) had to + re-establish in ``_read_points``. +- Writing to a ``UPath`` lands the root metadata artifact in the backend. Reading via + consolidated metadata is left as a failing test on purpose: the invariant is stated, + but the fix (threading ``use_consolidated=True`` through ``read_zarr`` / the store + opener) is intentionally open for review discussion rather than silently suppressed. +- A ``MemoryFileSystem`` subclass that refuses listing proves that ``SpatialData.read`` + does not depend on directory listing for basic elements (the precondition for serving + public HTTPS zarrs). + +These tests are strictly stronger than moto/s3 emulator coverage: they need no external +process, no subprocess, no network, and they pin the exact abstraction boundary that the +cloud-native follow-up must not regress. +""" + +from __future__ import annotations + +import pytest +from fsspec.implementations.memory import MemoryFileSystem +from upath import UPath + +from spatialdata import SpatialData +from spatialdata.testing import assert_spatial_data_objects_are_identical + + +def _fresh_memory_upath(key: str) -> UPath: + """Build a UPath bound to a fresh (per-test) in-memory fsspec filesystem. + + ``skip_instance_cache=True`` ensures every test gets an isolated memory backend so + tests cannot leak state across each other. + """ + fs = MemoryFileSystem(skip_instance_cache=True) + return UPath(f"memory://{key}.zarr", fs=fs) + + +# --------------------------------------------------------------------------- +# SpatialData.read is side-effect-free against the backend. +# --------------------------------------------------------------------------- + + +class TestReadIsSideEffectFree: + """``SpatialData.read`` must not mutate a single byte of the backend store. + + Using a memory filesystem as a tamper-evident substrate, we snapshot every key+bytes + before and after the read and assert full equality. This is strictly a public-interface + invariant: if ``read_zarr`` (or any element reader) ever silently wrote to a remote + backend, this test is the first to catch it. The lower-level guarantee that + ``_resolve_zarr_store`` forwards ``read_only=True`` to the backend store is unit-tested + separately in ``tests/io/test_store.py``. + """ + + def test_spatialdata_read_does_not_mutate_backend(self, images: SpatialData) -> None: + upath = _fresh_memory_upath("read-only-invariant") + images.write(upath, overwrite=True) + + fs = upath.fs + + def snapshot() -> dict[str, bytes]: + return {key: fs.cat_file(key) for key in fs.find(upath.path)} + + before = snapshot() + SpatialData.read(upath) + after = snapshot() + + assert before.keys() == after.keys(), ( + f"read added/removed backend keys; added={after.keys() - before.keys()}, " + f"removed={before.keys() - after.keys()}" + ) + # Equality on bytes (not just on keys) is what makes this tamper-evident: even a + # same-size rewrite of the same key would be caught. + assert before == after, "read mutated bytes in the backend store" + + +# --------------------------------------------------------------------------- +# Full SpatialData round-trip through a memory-backed UPath: the generic +# remote-backend stress test. +# --------------------------------------------------------------------------- + + +class TestMemoryUPathRoundtrip: + """Round-trip ``SpatialData`` objects through a memory-backed ``UPath``. + + Every code path from ``normalize_path`` -> ``_resolve_zarr_store`` -> + ``open_write_store`` / ``open_read_store`` -> ``zarr.open_group(FsspecStore)`` -> + ``io_raster`` / ``io_shapes`` / ``io_points`` / ``io_table`` is exercised identically + to how it would be against S3/Azure/GCS. If any of these regresses for remote backends, + one of these tests must break. + + Note that ``overwrite=True`` is required on every ``write()`` call that targets a + ``UPath`` (per the guard in ``_validate_can_safely_write_to_path``): remote existence + checks are unreliable across fsspec backends, so the caller must explicitly opt in. + """ + + def test_remote_write_without_overwrite_raises(self, images: SpatialData) -> None: + """Writing to a remote UPath with ``overwrite=False`` is rejected (existence is unreliable).""" + upath = _fresh_memory_upath("guard") + with pytest.raises(NotImplementedError, match="requires overwrite=True"): + images.write(upath) # overwrite defaults to False + + def test_roundtrip_images_only(self, images: SpatialData) -> None: + upath = _fresh_memory_upath("images") + images.write(upath, overwrite=True) + read = SpatialData.read(upath) + assert_spatial_data_objects_are_identical(images, read) + + def test_roundtrip_labels_only(self, labels: SpatialData) -> None: + upath = _fresh_memory_upath("labels") + labels.write(upath, overwrite=True) + read = SpatialData.read(upath) + assert_spatial_data_objects_are_identical(labels, read) + + def test_roundtrip_shapes_only(self, shapes: SpatialData) -> None: + upath = _fresh_memory_upath("shapes") + shapes.write(upath, overwrite=True) + read = SpatialData.read(upath) + assert_spatial_data_objects_are_identical(shapes, read) + + def test_roundtrip_points_only(self, points: SpatialData) -> None: + upath = _fresh_memory_upath("points") + points.write(upath, overwrite=True) + read = SpatialData.read(upath) + assert_spatial_data_objects_are_identical(points, read) + + def test_write_read_write_points_preserves_categorical_schema(self, points: SpatialData) -> None: + """Regression guard for the arrow-filesystem categorical round-trip. + + This PR migrated points io to ``to_parquet`` / ``read_parquet`` with + ``filesystem=arrow_fs``. ``read_parquet(filesystem=arrow_fs)`` eagerly pandas-ifies + pyarrow dictionaries into ``CategoricalDtype`` marked ``known=True`` with an empty + category list -- that would defeat ``write_points``'s ``as_known()`` normalization + and a subsequent ``to_parquet(filesystem=arrow_fs)`` would fail with a per-partition + schema mismatch (``dictionary`` vs ``dictionary``). The + fix lives in ``_read_points`` (demote such categoricals to unknown so that + ``write_points`` recomputes categories across partitions); this test pins it. + """ + upath1 = _fresh_memory_upath("points-rt1") + upath2 = _fresh_memory_upath("points-rt2") + points.write(upath1, overwrite=True) + read = SpatialData.read(upath1) + read.write(upath2, overwrite=True) + round_tripped = SpatialData.read(upath2) + assert_spatial_data_objects_are_identical(points, round_tripped) + + def test_write_read_write_full_sdata(self, full_sdata: SpatialData) -> None: + """End-to-end guard: a full sdata round-trips write -> read -> write cleanly. + + Pinned for the same reason as the points-only variant above: the arrow-filesystem + migration in this PR had to re-establish the categorical-schema invariant on the + read side so that write does not fail on the second pass. + """ + upath1 = _fresh_memory_upath("full-rt1") + upath2 = _fresh_memory_upath("full-rt2") + full_sdata.write(upath1, overwrite=True) + read = SpatialData.read(upath1) + read.write(upath2, overwrite=True) + round_tripped = SpatialData.read(upath2) + assert_spatial_data_objects_are_identical(full_sdata, round_tripped) + + def test_roundtrip_full_sdata(self, full_sdata: SpatialData) -> None: + upath = _fresh_memory_upath("full") + full_sdata.write(upath, overwrite=True) + read = SpatialData.read(upath) + assert_spatial_data_objects_are_identical(full_sdata, read) + + +# --------------------------------------------------------------------------- +# Consolidated metadata on read. +# --------------------------------------------------------------------------- + + +class TestConsolidatedMetadataOnRead: + """Writing produces a consolidated-metadata artifact; the read path consumes it. + + The invariant pinned here is: for an sdata built only of elements read by our own + code (shapes / points / tables), a single ``SpatialData.read`` over a remote-backed + ``UPath`` must issue very few metadata GETs. That is what consolidated metadata buys + us: one blob at the root (and one per element group, written by + ``_write_consolidated_metadata``) replaces an O(nodes) walk of small ``zarr.json`` + / ``.zattrs`` / ``.zarray`` / ``.zgroup`` files. + + Element types backed by ``ome-zarr-py`` (images / labels) still issue many small + GETs through ``ome_zarr``'s own ZarrLocation reader, which does a v2-style + ``.zattrs`` / ``.zmetadata`` walk regardless of the v3 consolidation we write at + the root. That is an upstream concern (``ome-zarr-py`` must learn to consume + ``consolidated_metadata`` on ``zarr.json``) and is intentionally *not* covered + here; it would wrongly make this test dependent on an external package's fix. + """ + + def test_write_produces_root_metadata_on_memory_upath(self, images: SpatialData) -> None: + upath = _fresh_memory_upath("consolidated") + images.write(upath, overwrite=True) + fs = upath.fs + # The root metadata artifact differs by zarr version: zarr v3 writes ``zarr.json`` + # at every group, zarr v2 writes ``.zmetadata`` at the consolidated root. Accepting + # either keeps the test valid across versions and asserts that the write path + # actually reaches the memory backend. + root_keys = [p.rsplit("/", 1)[-1] for p in fs.find(upath.path)] + assert "zarr.json" in root_keys or ".zmetadata" in root_keys, root_keys + + def test_read_zarr_opens_via_consolidated_metadata(self, shapes: SpatialData) -> None: + # Uses the ``shapes`` fixture specifically because images/labels are read through + # ``ome_zarr.reader.ZarrLocation`` which bypasses our ``open_zarr_for_read`` and + # performs a v2-style metadata walk upstream of our code. Shapes (and points / + # tables) are read by our own readers which go through ``open_zarr_for_read`` + # -- the function under test. + upath = _fresh_memory_upath("consolidated-read") + shapes.write(upath, overwrite=True) + + # Count store GETs on the memory fs. Without consolidation + zarr_format=3 pinning, + # reading this 3-shape sdata costs ~25 small GETs (v2-metadata auto-probes + a walk + # of per-element ``zarr.json``). With both it costs ~7. We monkeypatch the public + # ``cat_file`` (the one ``MemoryFileSystem`` exposes); targeting ``_cat_file`` would + # silently miss every call. + fs = upath.fs + original_cat_file = fs.cat_file + call_count = {"n": 0} + + def counting_cat_file(path, *args, **kwargs): + call_count["n"] += 1 + return original_cat_file(path, *args, **kwargs) + + fs.cat_file = counting_cat_file + try: + SpatialData.read(upath) + finally: + fs.cat_file = original_cat_file + + # The exact bound is a documented, loose sanity check, not a micro-benchmark. + # 10 comfortably covers the observed 7 GETs for 3 shapes while staying well below + # the ~25 that an unconsolidated / v2-probing read would incur. + assert call_count["n"] < 10, f"expected consolidated metadata to reduce GETs, saw {call_count['n']}" + + +# --------------------------------------------------------------------------- +# HTTP-like read-only filesystem: simulates a remote that does not support listing. +# --------------------------------------------------------------------------- + + +class _NoListMemoryFileSystem(MemoryFileSystem): + """MemoryFileSystem that refuses directory listing, approximating HTTPS zarr semantics. + + Public HTTPS zarr reads cannot do ``ls`` / ``find`` on an arbitrary prefix; they can + only GET known keys. This wrapper fails any listing operation so we can prove that + our read path does not rely on listing -- the precondition for public HTTPS datasets + to be readable. + """ + + def _ls(self, path, detail=True, **kwargs): # type: ignore[override] + raise NotImplementedError("listing disabled to simulate HTTP-like semantics") + + def ls(self, path, detail=True, **kwargs): # type: ignore[override] + raise NotImplementedError("listing disabled to simulate HTTP-like semantics") + + def find(self, path, **kwargs): # type: ignore[override] + raise NotImplementedError("listing disabled to simulate HTTP-like semantics") + + +class TestHttpLikeReadOnlyStore: + """Approximate HTTPS zarr semantics: a read-only filesystem that refuses listing. + + The point is not to re-test zarr's FsspecStore but to catch the case where our own + ``read_zarr`` implementation (or an element reader) assumes it can list a directory. + That is exactly the pattern that breaks when pointed at a real public HTTPS zarr. + """ + + def test_read_sdata_from_no_list_fs(self, images: SpatialData, tmp_path) -> None: + # Write locally, then copy bytes into a no-list memory fs so that the backend + # resembles a public HTTPS zarr: every known key is readable but listing is disabled. + local_path = tmp_path / "local.zarr" + images.write(local_path) + + no_list_fs = _NoListMemoryFileSystem(skip_instance_cache=True) + remote_root = "no-list.zarr" + for p in local_path.rglob("*"): + if p.is_file(): + rel = p.relative_to(local_path).as_posix() + no_list_fs.pipe_file(f"{remote_root}/{rel}", p.read_bytes()) + + upath = UPath(f"memory://{remote_root}", fs=no_list_fs) + read = SpatialData.read(upath) + assert_spatial_data_objects_are_identical(images, read) diff --git a/tests/io/test_utils.py b/tests/io/test_utils.py index 00ca6494..b0ba4296 100644 --- a/tests/io/test_utils.py +++ b/tests/io/test_utils.py @@ -11,10 +11,31 @@ from upath import UPath from spatialdata import SpatialData, read_zarr -from spatialdata._io._utils import get_dask_backing_files, handle_read_errors +from spatialdata._io._utils import get_dask_backing_files, handle_read_errors, join_fsspec_store_path from spatialdata.models import PointsModel +@pytest.mark.parametrize( + ("store_path", "relative_path", "expected"), + [ + # empty / slash-only relative paths return the store root unchanged + ("bucket/store.zarr", "", "bucket/store.zarr"), + ("bucket/store.zarr", "/", "bucket/store.zarr"), + ("store.zarr", "", "store.zarr"), + # leading slashes on the relative path are stripped before joining + ("bucket/store.zarr", "/images/img", "bucket/store.zarr/images/img"), + ("bucket/store.zarr", "//images/img", "bucket/store.zarr/images/img"), + # nested and single-segment relative paths + ("bucket/store.zarr", "images/img", "bucket/store.zarr/images/img"), + ("bucket/store.zarr", "points", "bucket/store.zarr/points"), + ("bucket", "key", "bucket/key"), + ], +) +def test_join_fsspec_store_path(store_path: str, relative_path: str, expected: str) -> None: + """`join_fsspec_store_path` joins a relative zarr-group path onto an fsspec store root.""" + assert join_fsspec_store_path(store_path, relative_path) == expected + + def test_backing_files_points(points): """Test the ability to identify the backing files of a dask dataframe from examining its computational graph""" with tempfile.TemporaryDirectory() as tmp_dir: