diff --git a/cortex/export/panels.py b/cortex/export/panels.py index 8335a33fd..a45cb31d6 100644 --- a/cortex/export/panels.py +++ b/cortex/export/panels.py @@ -52,6 +52,8 @@ def plot_panels( interpolation: str = "nearest", layers: int = 1, headless: bool = False, + contour_overlay: Optional[Union[str, Dataview]] = None, + contour_mode: str = "contours+fill", ) -> Figure: """Plot on the same figure a number of views, as defined by a list of panel specifications. @@ -110,6 +112,15 @@ def plot_panels( Software WebGL (SwiftShader) is used, so no GPU or display server is needed. (Default: False) + contour_overlay : Dataview, str, or None + Parcellation data whose borders will be drawn as contour lines. + Can be a Vertex/Dataview (automatically bundled into a Dataset), + or a string naming a view within an existing Dataset. (Default: None) + + contour_mode : str + Contour rendering mode. Options: "contours", "contours+fill", + "colored", "colored+fill". (Default: "contours+fill") + Returns ------- fig : matplotlib.Figure @@ -150,6 +161,8 @@ def plot_panels( interpolation=interpolation, layers=layers, headless=headless, + contour_overlay=contour_overlay, + contour_mode=contour_mode, ) fig = plt.figure(figsize=figsize) diff --git a/cortex/export/save_views.py b/cortex/export/save_views.py index 2ea757510..f884b965c 100644 --- a/cortex/export/save_views.py +++ b/cortex/export/save_views.py @@ -1,7 +1,7 @@ import contextlib import os import time -from typing import Any, Mapping, Sequence, TypedDict, Union +from typing import Any, Mapping, Optional, Sequence, TypedDict, Union import cortex @@ -38,6 +38,8 @@ def save_3d_views( trim: bool = True, sleep: float = 10, headless: bool = False, + contour_overlay: Optional[Union[str, Dataview]] = None, + contour_mode: str = "contours+fill", ) -> list[str]: """Saves 3D views of `volume` under multiple specifications. @@ -47,13 +49,13 @@ def save_3d_views( Parameters ---------- - volume: pycortex.Volume or pycortex.Vertex object + volume : pycortex.Volume, pycortex.Vertex, or pycortex.Dataset Data to be displayed. - base_name: str + base_name : str Base name for images. - list_angles: list of (str or dict) + list_angles : list of (str or dict) Views to be used. Should be of length one, or of the same length as `list_surfaces`. Choices are: 'left', 'right', 'front', 'back', 'top', 'bottom', 'flatmap', @@ -61,33 +63,33 @@ def save_3d_views( or tuple of (view_name, custom dictionary of parameters). See `angle_view_params` in this file for parameter dict examples. - list_surfaces: list of (str or dict) + list_surfaces : list of (str or dict) Surfaces to be used. Should be of length one, or of the same length as `list_angles`. Choices are: 'inflated', 'flatmap', 'fiducial', 'inflated_cut', or a custom dictionary of parameters. - viewer_params: dict + viewer_params : dict Parameters passed to the viewer. - interpolation: str + interpolation : str Interpolation used to visualize the data. Possible choices are "nearest", "trilinear". (Default: "nearest"). - layers: int + layers : int Number of layers between the white and pial surfaces to average prior to plotting the data. (Default: 1). - size: tuple of int + size : tuple of int Size of produced image (before trimming). - trim: bool + trim : bool Whether to trim the white borders of the image. - sleep: float > 0 + sleep : float > 0 Time in seconds, to let the viewer open. - headless: bool + headless : bool If True, render using a headless Chromium browser via Playwright instead of requiring the user to manually open a browser window. This allows the function to run fully autonomously without any user interaction. @@ -96,14 +98,39 @@ def save_3d_views( Software WebGL (SwiftShader) is used, so no GPU or display server is needed. (Default: False) + contour_overlay : Dataview, str, or None + Parcellation data whose borders will be drawn as contour lines. + Can be a Vertex/Dataview object (automatically bundled into a Dataset + with ``volume``), or a string naming a view within an existing Dataset + passed as ``volume``. (Default: None) + + contour_mode : str + Contour rendering mode when ``contour_overlay`` is set. + Options: "contours", "contours+fill", "colored", "colored+fill". + (Default: "contours+fill") + Returns ------- - file_names: list of str + file_names : list of str Image paths. """ msg = "list_angles and list_surfaces should have the same length." assert len(list_angles) == len(list_surfaces), msg + # If contour_overlay is a Dataview, bundle volume + overlay into a Dataset. + # Preserve the original volume reference for isinstance checks below. + _contour_overlay_name = None + _original_volume = volume + if contour_overlay is not None: + if isinstance(contour_overlay, str): + _contour_overlay_name = contour_overlay + else: + # contour_overlay is a Dataview — wrap into Dataset + _contour_overlay_name = "__contour_overlay__" + volume = cortex.Dataset( + data=volume, **{_contour_overlay_name: contour_overlay} + ) + # Create viewer — use a proper context manager so that cleanup always # runs, even if an exception occurs during rendering. if headless: @@ -117,8 +144,39 @@ def save_3d_views( # Wait for the viewer to be loaded time.sleep(sleep) - # Add interpolation and layers params only if we have a volume - if isinstance(volume, (cortex.Volume, cortex.Volume2D, cortex.VolumeRGB)): + # Set up contour overlay if requested + if _contour_overlay_name is not None: + _contour_mode_map = { + "contours": 1, + "contours+fill": 2, + "colored": 3, + "colored+fill": 4, + } + if contour_mode not in _contour_mode_map: + raise ValueError( + f"Unknown contour_mode {contour_mode!r}. " + f"Valid options: {list(_contour_mode_map.keys())}" + ) + _contour_mode_int = _contour_mode_map[contour_mode] + handle._set_view( + **{ + "surface.{subject}.contours.overlay": _contour_overlay_name, + } + ) + # Wait for overlay data to load + time.sleep(sleep) + handle._set_view( + **{ + "surface.{subject}.contours.mode": _contour_mode_int, + } + ) + time.sleep(1) + + # Add interpolation and layers params only if the primary data is a volume. + # Use _original_volume (before Dataset wrapping) for the type check. + if isinstance( + _original_volume, (cortex.Volume, cortex.Volume2D, cortex.VolumeRGB) + ): interpolation_params = { "surface.{subject}.sampler": interpolation, "surface.{subject}.layers": layers, @@ -126,7 +184,13 @@ def save_3d_views( else: interpolation_params = dict() - has_flatmap = hasattr(getattr(cortex.db, volume.subject).surfaces, "flat") + # Get subject name — handle both Dataview and Dataset + if hasattr(_original_volume, "subject"): + _subject = _original_volume.subject + else: + # Dataset: get subject from first view + _subject = next(iter(volume))[1].subject + has_flatmap = hasattr(getattr(cortex.db, _subject).surfaces, "flat") file_names: list[str] = [] for view, surface in zip(list_angles, list_surfaces): if isinstance(view, str): @@ -163,7 +227,7 @@ def save_3d_views( # wait for the view to have changed for _ in range(100): for k, v in this_view_params.items(): - k = k.format(subject=volume.subject) if "{subject}" in k else k + k = k.format(subject=_subject) if "{subject}" in k else k if handle.ui.get(k)[0] != v: print("waiting for", k, handle.ui.get(k)[0], "->", v) time.sleep(0.1) diff --git a/cortex/quickflat/composite.py b/cortex/quickflat/composite.py index b12104911..880b303e0 100644 --- a/cortex/quickflat/composite.py +++ b/cortex/quickflat/composite.py @@ -3,16 +3,39 @@ from .. import dataset from ..database import db from ..options import config -from .utils import _get_height, _get_extents, _convert_svg_kwargs, _get_images, _parse_defaults -from .utils import make_flatmap_image, _make_hatch_image, _get_fig_and_ax, get_flatmask, get_flatcache +from .utils import ( + _get_height, + _get_extents, + _convert_svg_kwargs, + _get_images, + _parse_defaults, +) +from .utils import ( + make_flatmap_image, + _make_hatch_image, + _get_fig_and_ax, + get_flatmask, + get_flatcache, +) """ --- Individual compositing functions --- """ -def add_curvature(fig, dataview, extents=None, height=None, threshold=True, contrast=None, - brightness=None, smooth=None, cmap='gray', recache=False, curvature_lims=0.5, - legacy_mode=False): +def add_curvature( + fig, + dataview, + extents=None, + height=None, + threshold=True, + contrast=None, + brightness=None, + smooth=None, + cmap="gray", + recache=False, + curvature_lims=0.5, + legacy_mode=False, +): """Add curvature layer to figure Parameters @@ -22,13 +45,13 @@ def add_curvature(fig, dataview, extents=None, height=None, threshold=True, cont dataview : cortex.Dataview object dataview containing data to be plotted, subject (surface identifier), and transform. extents : array-like - 4 values for [Left, Right, Top, Bottom] extents of image plotted. None defaults to + 4 values for [Left, Right, Top, Bottom] extents of image plotted. None defaults to extents of images already present in figure. height : scalar - Height of image. None defaults to height of images already present in figure. + Height of image. None defaults to height of images already present in figure. threshold : boolean Whether to apply a threshold to the curvature values to create a binary curvature image - (one shade for positive curvature, one shade for negative). `None` defaults to value + (one shade for positive curvature, one shade for negative). `None` defaults to value specified in the config file contrast : float, [0-1] or None Contrast of curvature image. 1 is maximal contrast (given brightness). If brightness is 0.5 @@ -59,11 +82,12 @@ def add_curvature(fig, dataview, extents=None, height=None, threshold=True, cont """ from matplotlib.colors import Normalize + if height is None: height = _get_height(fig) # Get curvature map as image - default_smoothing = config.get('curvature', 'smooth') - if default_smoothing.lower()=='none': + default_smoothing = config.get("curvature", "smooth") + if default_smoothing.lower() == "none": default_smoothing = None else: default_smoothing = np.float_(default_smoothing) @@ -83,10 +107,16 @@ def add_curvature(fig, dataview, extents=None, height=None, threshold=True, cont norm = Normalize(vmin=-0.5, vmax=0.5) curv_im = norm(curv) # Option to use thresholded curvature - default_threshold = config.get('curvature','threshold').lower() in ('true', 't', '1', 'y', 'yes') + default_threshold = config.get("curvature", "threshold").lower() in ( + "true", + "t", + "1", + "y", + "yes", + ) use_threshold_curvature = default_threshold if threshold is None else threshold if legacy_mode and use_threshold_curvature: - curvT = (curv>0).astype(np.float32) + curvT = (curv > 0).astype(np.float32) curvT[np.isnan(curv)] = np.nan curv = curvT if isinstance(curvature_lims, (list, tuple)): @@ -102,26 +132,38 @@ def add_curvature(fig, dataview, extents=None, height=None, threshold=True, cont curv_im[np.isnan(curv)] = np.nan # Get defaults for brightness, contrast if brightness is None: - brightness = float(config.get('curvature', 'brightness')) + brightness = float(config.get("curvature", "brightness")) if contrast is None: - contrast = float(config.get('curvature', 'contrast')) + contrast = float(config.get("curvature", "contrast")) # Scale and shift curvature image curv_im = (curv_im - 0.5) * contrast + brightness if extents is None: extents = _get_extents(fig) _, ax = _get_fig_and_ax(fig) - cvimg = ax.imshow(curv_im, - aspect='equal', - extent=extents, - cmap=cmap, - vmin=0, - vmax=1, - label='curvature', - zorder=0) + cvimg = ax.imshow( + curv_im, + aspect="equal", + extent=extents, + cmap=cmap, + vmin=0, + vmax=1, + label="curvature", + zorder=0, + ) return cvimg -def add_data(fig, braindata, height=1024, thick=32, depth=0.5, pixelwise=True, - sampler='nearest', recache=False, nanmean=False): + +def add_data( + fig, + braindata, + height=1024, + thick=32, + depth=0.5, + pixelwise=True, + sampler="nearest", + recache=False, + nanmean=False, +): """Add data to quickflat plot Parameters @@ -157,24 +199,44 @@ def add_data(fig, braindata, height=1024, thick=32, depth=0.5, pixelwise=True, if not isinstance(dataview, dataset.Dataview): # Unclear what this means. Clarify error in terms of pycortex classes # (please provide a [cortex.dataset.Dataview or whatever] instance) - raise TypeError('Please provide a Dataview, not a Dataset') + raise TypeError("Please provide a Dataview, not a Dataset") # Generate image (2D array, maybe 3D array) - im, extents = make_flatmap_image(dataview, recache=recache, pixelwise=pixelwise, sampler=sampler, - height=height, thick=thick, depth=depth, nanmean=nanmean) + im, extents = make_flatmap_image( + dataview, + recache=recache, + pixelwise=pixelwise, + sampler=sampler, + height=height, + thick=thick, + depth=depth, + nanmean=nanmean, + ) # Check whether dataview has a cmap instance cmapdict = dataview.get_cmapdict() # Plot _, ax = _get_fig_and_ax(fig) - img = ax.imshow(im, - aspect='equal', - extent=extents, - label='data', - zorder=1, - interpolation="nearest", - **cmapdict) + img = ax.imshow( + im, + aspect="equal", + extent=extents, + label="data", + zorder=1, + interpolation="nearest", + **cmapdict, + ) return img, extents -def add_rois(fig, dataview, extents=None, height=None, with_labels=True, roi_list=None, overlay_file=None, **kwargs): + +def add_rois( + fig, + dataview, + extents=None, + height=None, + with_labels=True, + roi_list=None, + overlay_file=None, + **kwargs, +): """Add ROIs layer to a figure NOTE: zorder for rois is 3 @@ -186,15 +248,15 @@ def add_rois(fig, dataview, extents=None, height=None, with_labels=True, roi_lis dataview : cortex.Dataview object dataview containing data to be plotted, subject (surface identifier), and transform. extents : array-like - 4 values for [Left, Right, Top, Bottom] extents of image plotted. None defaults to + 4 values for [Left, Right, Top, Bottom] extents of image plotted. None defaults to extents of images already present in figure. - height : scalar - Height of image. None defaults to height of images already present in figure. + height : scalar + Height of image. None defaults to height of images already present in figure. with_labels : bool Whether to display text labels on ROIs - roi_list : + roi_list : - kwargs : + kwargs : Returns ------- @@ -204,23 +266,36 @@ def add_rois(fig, dataview, extents=None, height=None, with_labels=True, roi_lis if extents is None: extents = _get_extents(fig) if height is None: - height = _get_height(fig) + height = _get_height(fig) svgobject = db.get_overlay(dataview.subject, overlay_file=overlay_file) svg_kws = _convert_svg_kwargs(kwargs) - layer_kws = _parse_defaults('rois_paths') + layer_kws = _parse_defaults("rois_paths") layer_kws.update(svg_kws) - im = svgobject.get_texture('rois', height, labels=with_labels, shape_list=roi_list, **layer_kws) + im = svgobject.get_texture( + "rois", height, labels=with_labels, shape_list=roi_list, **layer_kws + ) _, ax = _get_fig_and_ax(fig) - img = ax.imshow(im, - aspect='equal', - interpolation='bicubic', - extent=extents, - label='rois', - zorder=1000) + img = ax.imshow( + im, + aspect="equal", + interpolation="bicubic", + extent=extents, + label="rois", + zorder=1000, + ) return img -def add_sulci(fig, dataview, extents=None, height=None, with_labels=True, sulci_list=None, overlay_file=None, **kwargs): +def add_sulci( + fig, + dataview, + extents=None, + height=None, + with_labels=True, + sulci_list=None, + overlay_file=None, + **kwargs, +): """Add sulci layer to figure Parameters @@ -230,10 +305,10 @@ def add_sulci(fig, dataview, extents=None, height=None, with_labels=True, sulci_ dataview : cortex.Dataview object dataview containing data to be plotted, subject (surface identifier), and transform. extents : array-like - 4 values for [Left, Right, Top, Bottom] extents of image plotted. None defaults to + 4 values for [Left, Right, Top, Bottom] extents of image plotted. None defaults to extents of images already present in figure. height : scalar - Height of image. None defaults to height of images already present in figure. + Height of image. None defaults to height of images already present in figure. with_labels : bool Whether to display text labels for sulci sulci_list : list @@ -252,23 +327,35 @@ def add_sulci(fig, dataview, extents=None, height=None, with_labels=True, sulci_ """ svgobject = db.get_overlay(dataview.subject, overlay_file=overlay_file) svg_kws = _convert_svg_kwargs(kwargs) - layer_kws = _parse_defaults('sulci_paths') + layer_kws = _parse_defaults("sulci_paths") layer_kws.update(svg_kws) - sulc = svgobject.get_texture('sulci', height, labels=with_labels, shape_list=sulci_list, **layer_kws) + sulc = svgobject.get_texture( + "sulci", height, labels=with_labels, shape_list=sulci_list, **layer_kws + ) if extents is None: extents = _get_extents(fig) _, ax = _get_fig_and_ax(fig) - img = ax.imshow(sulc, - aspect='equal', - interpolation='bicubic', - extent=extents, - label='sulci', - zorder=5) + img = ax.imshow( + sulc, + aspect="equal", + interpolation="bicubic", + extent=extents, + label="sulci", + zorder=5, + ) return img -def add_hatch(fig, hatch_data, extents=None, height=None, hatch_space=4, - hatch_color=(0, 0, 0), sampler='nearest', recache=False): +def add_hatch( + fig, + hatch_data, + extents=None, + height=None, + hatch_space=4, + hatch_color=(0, 0, 0), + sampler="nearest", + recache=False, +): """Add hatching to figure at locations specified in hatch_data Parameters @@ -279,20 +366,20 @@ def add_hatch(fig, hatch_data, extents=None, height=None, hatch_space=4, cortex.Volume object created from data scaled from 0-1; locations with values of 1 will have hatching overlaid on them in the resulting image. extents : array-like - 4 values for [Left, Right, Top, Bottom] extents of image plotted. If None, defaults to + 4 values for [Left, Right, Top, Bottom] extents of image plotted. If None, defaults to extents of images already present in figure. - height : scalar - Height of image. if None, defaults to height of images already present in figure. - hatch_space : scalar + height : scalar + Height of image. if None, defaults to height of images already present in figure. + hatch_space : scalar Spacing between hatch lines, in pixels hatch_color : 3-tuple (R, G, B) tuple for color of hatching. Values for R,G,B should be 0-1 sampler : str - Name of sampling function used to sample underlying volume data. Options include + Name of sampling function used to sample underlying volume data. Options include 'trilinear','nearest','lanczos'; see functions in cortex.mapper.samplers.py for all options recache : boolean Whether or not to recache intermediate files. Takes longer to plot this way, potentially - resolves some errors. + resolves some errors. Returns ------- @@ -307,24 +394,32 @@ def add_hatch(fig, hatch_data, extents=None, height=None, hatch_space=4, extents = _get_extents(fig) if height is None: height = _get_height(fig) - hatchim = _make_hatch_image(hatch_data, height, sampler, recache=recache, - hatch_space=hatch_space) - hatchim[:,:,0] = hatch_color[0] - hatchim[:,:,1] = hatch_color[1] - hatchim[:,:,2] = hatch_color[2] + hatchim = _make_hatch_image( + hatch_data, height, sampler, recache=recache, hatch_space=hatch_space + ) + hatchim[:, :, 0] = hatch_color[0] + hatchim[:, :, 1] = hatch_color[1] + hatchim[:, :, 2] = hatch_color[2] _, ax = _get_fig_and_ax(fig) - img = ax.imshow(hatchim, - aspect="equal", - interpolation="bicubic", - extent=extents, - label='hatch', - zorder=2) + img = ax.imshow( + hatchim, + aspect="equal", + interpolation="bicubic", + extent=extents, + label="hatch", + zorder=2, + ) return img -def add_colorbar(fig, cimg, colorbar_ticks=None, colorbar_location=(0.4, 0.07, 0.2, 0.04), - orientation='horizontal'): +def add_colorbar( + fig, + cimg, + colorbar_ticks=None, + colorbar_location=(0.4, 0.07, 0.2, 0.04), + orientation="horizontal", +): """Add a colorbar to a flatmap plot Parameters @@ -332,12 +427,12 @@ def add_colorbar(fig, cimg, colorbar_ticks=None, colorbar_location=(0.4, 0.07, 0 fig : matplotlib Figure object Figure into which to insert colormap cimg : matplotlib.image.AxesImage object - Image for which to create colorbar. For reference, matplotlib.image.AxesImage + Image for which to create colorbar. For reference, matplotlib.image.AxesImage is the output of imshow() colorbar_ticks : array-like values for colorbar ticks colorbar_location : array-like - Four-long list, tuple, or array that specifies location for colorbar axes + Four-long list, tuple, or array that specifies location for colorbar axes [left, top, width, height] (?) orientation : string 'vertical' or 'horizontal' @@ -348,20 +443,25 @@ def add_colorbar(fig, cimg, colorbar_ticks=None, colorbar_location=(0.4, 0.07, 0 return cbar -def add_colorbar_2d(fig, cmap_name, colorbar_ticks, - colorbar_location=(0.425, 0.02, 0.15, 0.15), fontsize=12): +def add_colorbar_2d( + fig, + cmap_name, + colorbar_ticks, + colorbar_location=(0.425, 0.02, 0.15, 0.15), + fontsize=12, +): """Add a 2D colorbar to a flatmap plot Parameters ---------- fig : matplotlib Figure object cimg : matplotlib.image.AxesImage object - Image for which to create colorbar. For reference, matplotlib.image.AxesImage + Image for which to create colorbar. For reference, matplotlib.image.AxesImage is the output of imshow() colorbar_ticks : array-like values for colorbar ticks colorbar_location : array-like - Four-long list, tuple, or array that specifies location for colorbar axes + Four-long list, tuple, or array that specifies location for colorbar axes [left, top, width, height] (?) orientation : string 'vertical' or 'horizontal' @@ -369,11 +469,12 @@ def add_colorbar_2d(fig, cmap_name, colorbar_ticks, # a bit sketchy - lazy imports import matplotlib.pyplot as plt import os - cmap_dir = config.get('webgl', 'colormaps') - cim = plt.imread(os.path.join(cmap_dir, cmap_name + '.png')) + + cmap_dir = config.get("webgl", "colormaps") + cim = plt.imread(os.path.join(cmap_dir, cmap_name + ".png")) fig, _ = _get_fig_and_ax(fig) fig.add_axes(colorbar_location) - cbar = plt.imshow(cim, extent=colorbar_ticks, interpolation='bilinear') + cbar = plt.imshow(cim, extent=colorbar_ticks, interpolation="bilinear") cbar.axes.set_xticks(colorbar_ticks[:2]) cbar.axes.set_xticklabels(colorbar_ticks[:2], fontdict=dict(size=fontsize)) cbar.axes.set_yticks(colorbar_ticks[2:]) @@ -381,8 +482,18 @@ def add_colorbar_2d(fig, cmap_name, colorbar_ticks, return cbar -def add_custom(fig, dataview, svgfile, layer, extents=None, height=None, with_labels=False, - shape_list=None, **kwargs): + +def add_custom( + fig, + dataview, + svgfile, + layer, + extents=None, + height=None, + with_labels=False, + shape_list=None, + **kwargs, +): """Add a custom data layer Parameters @@ -397,10 +508,10 @@ def add_custom(fig, dataview, svgfile, layer, extents=None, height=None, with_la layer : string Layer name within custom svg file to display extents : array-like - 4 values for [Left, Right, Bottom, Top] extents of image plotted. If None, defaults to + 4 values for [Left, Right, Bottom, Top] extents of image plotted. If None, defaults to extents of images already present in figure. height : scalar - Height of image. if None, defaults to height of images already present in figure. + Height of image. if None, defaults to height of images already present in figure. with_labels : bool Whether to display text labels on ROIs shape_list : list @@ -419,6 +530,7 @@ def add_custom(fig, dataview, svgfile, layer, extents=None, height=None, with_la """ from ..svgoverlay import get_overlay + if height is None: height = _get_height(fig) if extents is None: @@ -428,27 +540,37 @@ def add_custom(fig, dataview, svgfile, layer, extents=None, height=None, with_la svg_kws = _convert_svg_kwargs(kwargs) try: # Check for layer if it exists - layer_kws = _parse_defaults(layer+'_paths') + layer_kws = _parse_defaults(layer + "_paths") layer_kws.update(svg_kws) except: layer_kws = svg_kws - im = extra_svg.get_texture(layer, height, - labels=with_labels, - shape_list=shape_list, - **layer_kws) + im = extra_svg.get_texture( + layer, height, labels=with_labels, shape_list=shape_list, **layer_kws + ) _, ax = _get_fig_and_ax(fig) - img = ax.imshow(im, - aspect="equal", - interpolation="nearest", - extent=extents, - label='custom', - zorder=6) + img = ax.imshow( + im, + aspect="equal", + interpolation="nearest", + extent=extents, + label="custom", + zorder=6, + ) return img -def add_connected_vertices(fig, dataview, exclude_border_width=None, - height=None, extents=None, recache=False, - color=(1.0, 0.5, 0.1, 0.6), linewidth=0.75, - alpha=1.0, **kwargs): + +def add_connected_vertices( + fig, + dataview, + exclude_border_width=None, + height=None, + extents=None, + recache=False, + color=(1.0, 0.5, 0.1, 0.6), + linewidth=0.75, + alpha=1.0, + **kwargs, +): """Plot lines btw distant vertices that are within the same voxel Parameters @@ -492,11 +614,13 @@ def add_connected_vertices(fig, dataview, exclude_border_width=None, if extents is None: extents = _get_extents(fig) if height is None: - height = _get_height(fig) + height = _get_height(fig) subject = dataview.subject xfmname = dataview.xfmname if xfmname is None: - raise ValueError("Dataview for add_connected_vertices must be a Volume! You seem to have provided vertex data.") + raise ValueError( + "Dataview for add_connected_vertices must be a Volume! You seem to have provided vertex data." + ) # print('computing shared voxels') shared_voxels = db.get_shared_voxels(subject, xfmname, recache=recache, **kwargs) # print('Finished computing shared voxels') @@ -506,14 +630,20 @@ def add_connected_vertices(fig, dataview, exclude_border_width=None, if exclude_border_width: # Finding vertices that map to the border of the flatmap - img = np.nan * np.ones(mask.shape) - img[mask] = pixmap * np.arange(n_verts) # mapper.nverts + img = np.nan * np.ones(mask.shape) + img[mask] = pixmap * np.arange(n_verts) # mapper.nverts border_mask = binary_dilation(~mask, iterations=exclude_border_width) ^ (~mask) border_vertices = set(img[border_mask].astype(int)) - shared_voxels = np.array([a for a in shared_voxels if ((a[1] not in border_vertices) and (a[2] not in border_vertices))]) + shared_voxels = np.array( + [ + a + for a in shared_voxels + if ((a[1] not in border_vertices) and (a[2] not in border_vertices)) + ] + ) valid_vert_mask = np.array(pixmap.sum(0) > 0).flatten() - valid_verts = np.arange(n_verts)[valid_vert_mask] # mapper.nverts + valid_verts = np.arange(n_verts)[valid_vert_mask] # mapper.nverts # Assure both vertices in each pair are not in the medial wall vtx1valid = np.isin(shared_voxels[:, 1], valid_verts) vtx2valid = np.isin(shared_voxels[:, 2], valid_verts) @@ -532,16 +662,21 @@ def add_connected_vertices(fig, dataview, exclude_border_width=None, # (This is the most time consuming step, as it draws many lines) # print('plotting lines...') fig, ax = _get_fig_and_ax(fig) - lc = LineCollection(pix_array_scaled, - transform=fig.transFigure, - figure=fig, - colors=color, - alpha=alpha, - linewidths=linewidth) + lc = LineCollection( + pix_array_scaled, + transform=fig.transFigure, + figure=fig, + colors=color, + alpha=alpha, + linewidths=linewidth, + ) lc_object = ax.add_collection(lc) return lc_object -def add_cutout(fig, name, dataview, layers=None, height=None, extents=None, overlay_file=None): + +def add_cutout( + fig, name, dataview, layers=None, height=None, extents=None, overlay_file=None +): """Apply a cutout mask to extant layers in flatmap figure Parameters @@ -574,13 +709,12 @@ def add_cutout(fig, name, dataview, layers=None, height=None, extents=None, over for co_name, co_shape in svgobject.cutouts.shapes.items(): co_shape.visible = co_name == name # Get cutout image (now all white = 1, black = 0) - svg_kws = _convert_svg_kwargs(dict(fillcolor="white", - fillalpha=1.0, - linecolor="white", - linewidth=2)) - co = svgobject.get_texture('cutouts', height, labels=False, **svg_kws)[..., 0] + svg_kws = _convert_svg_kwargs( + dict(fillcolor="white", fillalpha=1.0, linecolor="white", linewidth=2) + ) + co = svgobject.get_texture("cutouts", height, labels=False, **svg_kws)[..., 0] if not np.any(co): - raise Exception('No pixels in cutout region {}!'.format(name)) + raise Exception("No pixels in cutout region {}!".format(name)) # Bounding box indices LL, RR, BB, TT = np.nan, np.nan, np.nan, np.nan @@ -588,41 +722,49 @@ def add_cutout(fig, name, dataview, layers=None, height=None, extents=None, over for layer_name, im_layer in layers.items(): im = im_layer.get_array() - # Reconcile occasional 1-pixel difference between flatmap image layers + # Reconcile occasional 1-pixel difference between flatmap image layers # that are generated by different functions if not all([np.abs(aa - bb) <= 1 for aa, bb in zip(im.shape, co.shape)]): raise Exception("Shape mismatch btw cutout and data!") - if any([np.abs(aa - bb) > 0 and np.abs(aa - bb) < 2 for aa, bb in zip(im.shape, co.shape)]): + if any( + [ + np.abs(aa - bb) > 0 and np.abs(aa - bb) < 2 + for aa, bb in zip(im.shape, co.shape) + ] + ): from scipy.misc import imresize - print('Resizing! {} to {}'.format(co.shape, im.shape[:2])) - layer_cutout = imresize(co, im.shape[:2]).astype(np.float32)/255. + + print("Resizing! {} to {}".format(co.shape, im.shape[:2])) + layer_cutout = imresize(co, im.shape[:2]).astype(np.float32) / 255.0 else: layer_cutout = copy.copy(co) # Handle different types of alpha layers. Useful for RGBVolumes if nothing else. if im.dtype == np.uint8: - im = np.cast['float32'](im)/255. - im[:,:,3] *= layer_cutout + im = np.cast["float32"](im) / 255.0 + im[:, :, 3] *= layer_cutout h, w, cdim = [float(v) for v in im.shape] else: - if np.ndim(im)==3: - im[:,:,3] *= layer_cutout + if np.ndim(im) == 3: + im[:, :, 3] *= layer_cutout h, w, cdim = [float(v) for v in im.shape] - elif np.ndim(im)==2: - im[layer_cutout==0] = np.nan + elif np.ndim(im) == 2: + im[layer_cutout == 0] = np.nan h, w = [float(v) for v in im.shape] y, x = np.nonzero(layer_cutout) l, r, b, t = extents - x_span = np.abs(r-l) - y_span = np.abs(t-b) - extents_new = [l + x.min() / w * x_span, - l + x.max() / w * x_span, - t + y.min() / h * y_span, - t + y.max() / h * y_span] + x_span = np.abs(r - l) + y_span = np.abs(t - b) + extents_new = [ + l + x.min() / w * x_span, + l + x.max() / w * x_span, + t + y.min() / h * y_span, + t + y.max() / h * y_span, + ] # Bounding box indices iy, ix = ((y.min(), y.max()), (x.min(), x.max())) - tmp = im[iy[0]:iy[1], ix[0]:ix[1]] + tmp = im[iy[0] : iy[1], ix[0] : ix[1]] im_layer.set_array(tmp) im_layer.set_extent(extents_new) @@ -641,3 +783,127 @@ def add_cutout(fig, name, dataview, layers=None, height=None, extents=None, over fig.set_size_inches(inch_size[0], inch_size[1]) return + + +def _detect_label_borders(label_img): + """Detect border pixels in a 2D label image. + + A pixel is a border if any of its 4-connected neighbors has a different + value. NaN pixels (outside brain mask) are not borders. + + Parameters + ---------- + label_img : ndarray, shape (H, W) or (H, W, C) + Label image. If 3D (e.g. RGBA), the first channel is used. + + Returns + ------- + border : ndarray, shape (H, W), dtype bool + True at border pixels. + """ + if label_img.ndim == 3: + vals = label_img[:, :, 0] + else: + vals = label_img + + valid = ~np.isnan(vals) + border = np.zeros(vals.shape, dtype=bool) + for di, dj in [(-1, 0), (1, 0), (0, -1), (0, 1)]: + shifted = np.roll(np.roll(vals, di, axis=0), dj, axis=1) + shifted_valid = ~np.isnan(shifted) + # Invalidate wrapped edges to avoid false borders from np.roll + if di == -1: + shifted_valid[-1, :] = False + elif di == 1: + shifted_valid[0, :] = False + if dj == -1: + shifted_valid[:, -1] = False + elif dj == 1: + shifted_valid[:, 0] = False + border |= (vals != shifted) & valid & shifted_valid + return border + + +def add_contours( + fig, + contour_data, + extents=None, + height=None, + linewidth=1, + linecolor=(0, 0, 0, 1), + sampler="nearest", + recache=False, +): + """Add contour borders of parcellation data to a quickflat plot. + + Parameters + ---------- + fig : matplotlib figure or axes + Figure or axes to plot into. + contour_data : cortex.Dataview + Parcellation/label data whose borders will be drawn. + extents : array-like, optional + [Left, Right, Top, Bottom] extents. None uses extents from existing images. + height : int, optional + Height of the flatmap image. None uses height of existing images. + linewidth : int + Width of contour lines in pixels (default: 1). + linecolor : tuple of float + (R, G, B, A) color for contour lines (default: black). + sampler : str + Sampling method (default: 'nearest' to preserve label boundaries). + recache : bool + Whether to recache intermediate files. + + Returns + ------- + img : matplotlib.image.AxesImage + Axes image object for plotted contour overlay. + """ + dataview = dataset.normalize(contour_data) + if not isinstance(dataview, dataset.Dataview): + raise TypeError("Please provide a Dataview (e.g. cortex.Vertex), not a Dataset") + + try: + if extents is None: + extents = _get_extents(fig) + except ValueError: + extents = None + if height is None: + try: + height = _get_height(fig) + except (ValueError, IndexError): + height = 1024 + + # Generate flatmap image of the label data + label_img, extents_out = make_flatmap_image( + dataview, height=height, recache=recache, sampler=sampler + ) + + if extents is None: + extents = extents_out + + # Detect borders + border = _detect_label_borders(label_img) + + # Optionally dilate for thicker lines + if linewidth > 1: + from scipy.ndimage import binary_dilation + + struct = np.ones((linewidth, linewidth)) + border = binary_dilation(border, structure=struct) + + # Create RGBA overlay image + rgba = np.zeros(label_img.shape[:2] + (4,), dtype=np.float32) + rgba[border] = linecolor + + _, ax = _get_fig_and_ax(fig) + img = ax.imshow( + rgba, + aspect="equal", + extent=extents, + interpolation="nearest", + zorder=4, + label="contours", + ) + return img diff --git a/cortex/quickflat/view.py b/cortex/quickflat/view.py index 1d7102905..0b6cdf7ea 100644 --- a/cortex/quickflat/view.py +++ b/cortex/quickflat/view.py @@ -4,7 +4,7 @@ import binascii import numpy as np import numpy.typing as npt -from typing import Optional, Union, IO +from typing import Literal, Optional, Union, IO from matplotlib.axes import Axes from matplotlib.figure import Figure @@ -16,33 +16,68 @@ default_colorbar_locations = { - 'left': (.0, .07, .2, .04), - 'center': (.4, .07, .2, .04), - 'right': (.7, .07, .2, .04) + "left": (0.0, 0.07, 0.2, 0.04), + "center": (0.4, 0.07, 0.2, 0.04), + "right": (0.7, 0.07, 0.2, 0.04), } -def _check_colorbar_location(colorbar_location: Union[tuple[float, float, float, float], str]) -> tuple[float, float, float, float]: +def _check_colorbar_location( + colorbar_location: Union[tuple[float, float, float, float], str], +) -> tuple[float, float, float, float]: if isinstance(colorbar_location, (tuple, list)): return colorbar_location if colorbar_location not in default_colorbar_locations: - raise ValueError("colorbar_location must be one of {}".format( - list(default_colorbar_locations.keys()))) + raise ValueError( + "colorbar_location must be one of {}".format( + list(default_colorbar_locations.keys()) + ) + ) return default_colorbar_locations[colorbar_location] -def make_figure(braindata: dataset.Dataview, recache: bool=False, pixelwise: bool=True, thick: int=32, sampler: str='nearest', - height: int=1024, dpi: int=100, depth: float=0.5, with_rois: bool=True, with_sulci: bool=False, - with_labels: bool=True, with_colorbar: bool=True, with_borders: bool=False, - with_dropout: Union[bool, float]=False, with_curvature: bool=False, extra_disp: Optional[tuple[str, str]]=None, - with_connected_vertices: bool=False, overlay_file: Optional[str]=None, - linewidth: Optional[int]=None, linecolor: Optional[ColorType]=None, roifill: Optional[ColorType]=None, shadow: Optional[int]=None, - labelsize: Optional[str]=None, labelcolor: Optional[ColorType]=None, cutout: Optional[str]=None, curvature_brightness: Optional[float]=None, - curvature_contrast: Optional[float]=None, curvature_threshold: Optional[bool]=None, fig: Optional[Union[Figure, Axes]]=None, extra_hatch: Optional[tuple[dataset.Dataview, tuple[float, float, float]]]=None, - colorbar_ticks: Optional[npt.ArrayLike]=None, colorbar_location: Union[tuple[float, float, float, float], str]='center', roi_list: Optional[list[str]]=None, sulci_list: Optional[list[str]]=None, - nanmean: bool=False) -> Figure: +def make_figure( + braindata: dataset.Dataview, + recache: bool = False, + pixelwise: bool = True, + thick: int = 32, + sampler: str = "nearest", + height: int = 1024, + dpi: int = 100, + depth: float = 0.5, + with_rois: bool = True, + with_sulci: bool = False, + with_labels: bool = True, + with_colorbar: bool = True, + with_borders: bool = False, + with_dropout: Union[bool, float] = False, + with_curvature: bool = False, + extra_disp: Optional[tuple[str, str]] = None, + with_connected_vertices: bool = False, + overlay_file: Optional[str] = None, + linewidth: Optional[int] = None, + linecolor: Optional[ColorType] = None, + roifill: Optional[ColorType] = None, + shadow: Optional[int] = None, + labelsize: Optional[str] = None, + labelcolor: Optional[ColorType] = None, + cutout: Optional[str] = None, + curvature_brightness: Optional[float] = None, + curvature_contrast: Optional[float] = None, + curvature_threshold: Optional[bool] = None, + fig: Optional[Union[Figure, Axes]] = None, + extra_hatch: Optional[tuple[dataset.Dataview, tuple[float, float, float]]] = None, + colorbar_ticks: Optional[npt.ArrayLike] = None, + colorbar_location: Union[tuple[float, float, float, float], str] = "center", + roi_list: Optional[list[str]] = None, + sulci_list: Optional[list[str]] = None, + nanmean: bool = False, + with_contours: Union[Literal[False], dataset.Dataview] = False, + contour_linewidth: Optional[int] = None, + contour_linecolor: Optional[ColorType] = None, +) -> Figure: """Show a Volume or Vertex on a flatmap with matplotlib. Parameters @@ -123,12 +158,22 @@ def make_figure(braindata: dataset.Dataview, recache: bool=False, pixelwise: boo figure into which to plot flatmap nanmean : bool, optional (default = False) If True, NaNs in the data will be ignored when averaging across layers. + with_contours : cortex.Dataview or False, optional + Parcellation data whose label boundaries will be drawn as contour + lines on top of the plotted data. Pass a Vertex (or other Dataview) + with discrete labels. False (default) disables contours. + contour_linewidth : int, optional + Width of contour lines in pixels. None defaults to 1. + contour_linecolor : tuple of float, optional + (R, G, B, A) color for contour lines. None defaults to black. """ from matplotlib import pyplot as plt dataview = dataset.normalize(braindata) if not isinstance(dataview, dataset.Dataview): - raise TypeError('Please provide a Dataview (e.g. an instance of cortex.Volume, cortex.Vertex, etc), not a Dataset') + raise TypeError( + "Please provide a Dataview (e.g. an instance of cortex.Volume, cortex.Vertex, etc), not a Dataset" + ) if fig is None: fig_resize = True fig = plt.figure() @@ -143,20 +188,33 @@ def make_figure(braindata: dataset.Dataview, recache: bool=False, pixelwise: boo fig = ax.figure # Add data - data_im, extents = composite.add_data(ax, dataview, pixelwise=pixelwise, thick=thick, sampler=sampler, - height=height, depth=depth, recache=recache, nanmean=nanmean) + data_im, extents = composite.add_data( + ax, + dataview, + pixelwise=pixelwise, + thick=thick, + sampler=sampler, + height=height, + depth=depth, + recache=recache, + nanmean=nanmean, + ) layers = dict(data=data_im) # Add curvature if with_curvature: - curv_im = composite.add_curvature(ax, dataview, extents, - brightness=curvature_brightness, - contrast=curvature_contrast, - threshold=curvature_threshold, - curvature_lims=0.5, - legacy_mode=False, - recache=recache) - layers['curvature'] = curv_im + curv_im = composite.add_curvature( + ax, + dataview, + extents, + brightness=curvature_brightness, + contrast=curvature_contrast, + threshold=curvature_threshold, + curvature_lims=0.5, + legacy_mode=False, + recache=recache, + ) + layers["curvature"] = curv_im # Add dropout if with_dropout is not False: @@ -167,43 +225,105 @@ def make_figure(braindata: dataset.Dataview, recache: bool=False, pixelwise: boo hatch_data = None dropout_power = 20 if with_dropout is True else with_dropout if hatch_data is None: - hatch_data = utils.get_dropout(dataview.subject, dataview.xfmname, - power=dropout_power) + hatch_data = utils.get_dropout( + dataview.subject, dataview.xfmname, power=dropout_power + ) - drop_im = composite.add_hatch(ax, hatch_data, extents=extents, height=height, - sampler=sampler, recache=recache) - layers['dropout'] = drop_im + drop_im = composite.add_hatch( + ax, + hatch_data, + extents=extents, + height=height, + sampler=sampler, + recache=recache, + ) + layers["dropout"] = drop_im # Add extra hatching if extra_hatch is not None: hatch_data2, hatch_color = extra_hatch - hatch_im = composite.add_hatch(ax, hatch_data2, extents=extents, height=height, - sampler=sampler, recache=recache) - layers['hatch'] = hatch_im + hatch_im = composite.add_hatch( + ax, + hatch_data2, + extents=extents, + height=height, + sampler=sampler, + recache=recache, + ) + layers["hatch"] = hatch_im # Add rois if with_rois: - roi_im = composite.add_rois(ax, dataview, extents=extents, height=height, linewidth=linewidth, linecolor=linecolor, - roifill=roifill, shadow=shadow, labelsize=labelsize, labelcolor=labelcolor, - with_labels=with_labels, overlay_file=overlay_file, - roi_list=roi_list) - layers['rois'] = roi_im + roi_im = composite.add_rois( + ax, + dataview, + extents=extents, + height=height, + linewidth=linewidth, + linecolor=linecolor, + roifill=roifill, + shadow=shadow, + labelsize=labelsize, + labelcolor=labelcolor, + with_labels=with_labels, + overlay_file=overlay_file, + roi_list=roi_list, + ) + layers["rois"] = roi_im # Add sulci if with_sulci: - sulc_im = composite.add_sulci(ax, dataview, extents=extents, height=height, linewidth=linewidth, linecolor=linecolor, - shadow=shadow, labelsize=labelsize, labelcolor=labelcolor, with_labels=with_labels, - overlay_file=overlay_file, sulci_list=sulci_list) - layers['sulci'] = sulc_im + sulc_im = composite.add_sulci( + ax, + dataview, + extents=extents, + height=height, + linewidth=linewidth, + linecolor=linecolor, + shadow=shadow, + labelsize=labelsize, + labelcolor=labelcolor, + with_labels=with_labels, + overlay_file=overlay_file, + sulci_list=sulci_list, + ) + layers["sulci"] = sulc_im # Add custom if extra_disp is not None: svgfile, layer = extra_disp - custom_im = composite.add_custom(ax, dataview, svgfile, layer, height=height, extents=extents, - linewidth=linewidth, linecolor=linecolor, shadow=shadow, labelsize=labelsize, - labelcolor=labelcolor, with_labels=with_labels) - layers['custom'] = custom_im + custom_im = composite.add_custom( + ax, + dataview, + svgfile, + layer, + height=height, + extents=extents, + linewidth=linewidth, + linecolor=linecolor, + shadow=shadow, + labelsize=labelsize, + labelcolor=labelcolor, + with_labels=with_labels, + ) + layers["custom"] = custom_im + # Add contours + if with_contours is not False: + contour_kw = {} + if contour_linewidth is not None: + contour_kw["linewidth"] = contour_linewidth + if contour_linecolor is not None: + contour_kw["linecolor"] = contour_linecolor + contour_im = composite.add_contours( + ax, + with_contours, + extents=extents, + height=height, + recache=recache, + **contour_kw, + ) + layers["contours"] = contour_im # Add connector lines btw connected vertices if with_connected_vertices: vertex_lines = composite.add_connected_vertices(ax, dataview, recache=recache) - ax.axis('off') + ax.axis("off") ax.set_xlim(extents[0], extents[1]) ax.set_ylim(extents[2], extents[3]) @@ -213,32 +333,44 @@ def make_figure(braindata: dataset.Dataview, recache: bool=False, pixelwise: boo # Add (apply) cutout of flatmap if cutout is not None: - extents = composite.add_cutout(ax, cutout, dataview, layers, overlay_file=overlay_file) + extents = composite.add_cutout( + ax, cutout, dataview, layers, overlay_file=overlay_file + ) if with_colorbar: colorbar_location = _check_colorbar_location(colorbar_location) # Allow 2D colorbars: if isinstance(dataview, dataset.Dataview2D): - colorbar_ticks = np.round([ - dataview.vmin, dataview.vmax, - dataview.vmin2, dataview.vmax2 - ], 2) + colorbar_ticks = np.round( + [dataview.vmin, dataview.vmax, dataview.vmin2, dataview.vmax2], 2 + ) colorbar = composite.add_colorbar_2d( - ax, dataview.cmap, colorbar_ticks, - colorbar_location=colorbar_location) + ax, dataview.cmap, colorbar_ticks, colorbar_location=colorbar_location + ) else: colorbar = composite.add_colorbar( - ax, data_im, + ax, + data_im, colorbar_location=colorbar_location, - colorbar_ticks=colorbar_ticks + colorbar_ticks=colorbar_ticks, ) # Reset axis to main figure axis plt.sca(ax) return fig -def make_png(fname: Union[str, os.PathLike, IO], braindata: dataset.Dataview, recache: bool=False, pixelwise: bool=True, sampler: str='nearest', height: int=1024, - bgcolor: Optional[ColorType]=None, dpi: int=100, **kwargs) -> None: + +def make_png( + fname: Union[str, os.PathLike, IO], + braindata: dataset.Dataview, + recache: bool = False, + pixelwise: bool = True, + sampler: str = "nearest", + height: int = 1024, + bgcolor: Optional[ColorType] = None, + dpi: int = 100, + **kwargs, +) -> None: """Create a PNG of the VertexData or VolumeData on a flatmap. Parameters @@ -289,12 +421,15 @@ def make_png(fname: Union[str, os.PathLike, IO], braindata: dataset.Dataview, re (R, G, B, A) specification for the label color """ from matplotlib import pyplot as plt - fig = make_figure(braindata, - recache=recache, - pixelwise=pixelwise, - sampler=sampler, - height=height, - **kwargs) + + fig = make_figure( + braindata, + recache=recache, + pixelwise=pixelwise, + sampler=sampler, + height=height, + **kwargs, + ) imsize = fig.get_axes()[0].get_images()[0].get_size() fig.set_size_inches(np.array(imsize)[::-1] / float(dpi)) @@ -305,8 +440,18 @@ def make_png(fname: Union[str, os.PathLike, IO], braindata: dataset.Dataview, re fig.clf() plt.close(fig) -def make_svg(fname, braindata, with_labels=False, with_curvature=True, layers=['rois'], - height=1024, overlay_file=None, with_dropout=False, **kwargs): + +def make_svg( + fname, + braindata, + with_labels=False, + with_curvature=True, + layers=["rois"], + height=1024, + overlay_file=None, + with_dropout=False, + **kwargs, +): """Save an svg file of the desired flatmap. This function creates an SVG file with vector graphic ROIs overlaid on a single png image. @@ -344,9 +489,9 @@ def make_svg(fname, braindata, with_labels=False, with_curvature=True, layers=[' arr, extents = make_flatmap_image(braindata, height=height, **kwargs) # Set nans to alpha = 0. to enable transparency when saving as PNG mask_nans = np.isnan(arr[..., 3]) - arr[mask_nans, 3] = 0. + arr[mask_nans, 3] = 0.0 - if hasattr(braindata, 'cmap'): + if hasattr(braindata, "cmap"): imsave(fp, arr, cmap=braindata.cmap, vmin=braindata.vmin, vmax=braindata.vmax) else: imsave(fp, arr) @@ -357,6 +502,7 @@ def make_svg(fname, braindata, with_labels=False, with_curvature=True, layers=[' if with_curvature: # no options. learn to love it. from cortex import db + fpc = io.BytesIO() curv_vertices = db.get_surfinfo(braindata.subject) curv_arr, _ = make_flatmap_image(curv_vertices, height=height) @@ -364,7 +510,7 @@ def make_svg(fname, braindata, with_labels=False, with_curvature=True, layers=[' curv_arr = np.where(curv_arr > 0, 0.5, 0.25) curv_arr[mask] = np.nan - imsave(fpc, curv_arr, cmap='Greys_r', vmin=0, vmax=1) + imsave(fpc, curv_arr, cmap="Greys_r", vmin=0, vmax=1) fpc.seek(0) image_data = [binascii.b2a_base64(fpc.read()), pngdata] @@ -425,16 +571,16 @@ def make_gif(output_destination, volumes, frame_duration=1, **figure_kwargs): fig = plt.figure(figsize=(12, 6), dpi=100) _ = make_figure(volumes[name], fig=fig, **figure_kwargs) _ = fig.suptitle(name) - path = os.path.join(tmpdir.name, str(i) + '.png') + path = os.path.join(tmpdir.name, str(i) + ".png") fig.savefig(path) images.append(imageio.imread(path)) _ = plt.close(fig) tmpdir.cleanup() - imageio.mimsave(output_destination, images, format='gif', duration=frame_duration) + imageio.mimsave(output_destination, images, format="gif", duration=frame_duration) - if hasattr(output_destination, 'seek'): + if hasattr(output_destination, "seek"): output_destination.seek(0) @@ -442,9 +588,25 @@ def show(*args, **kwargs): """Wrapper for make_figure()""" return make_figure(*args, **kwargs) -def make_movie(name, data, subject, xfmname, recache=False, height=1024, - sampler='nearest', dpi=100, tr=2, interp='linear', fps=30, - vcodec='libtheora', bitrate="8000k", vmin=None, vmax=None, **kwargs): + +def make_movie( + name, + data, + subject, + xfmname, + recache=False, + height=1024, + sampler="nearest", + dpi=100, + tr=2, + interp="linear", + fps=30, + vcodec="libtheora", + bitrate="8000k", + vmin=None, + vmax=None, + **kwargs, +): """Create a movie of an 4D data set""" raise NotImplementedError import sys @@ -457,7 +619,9 @@ def make_movie(name, data, subject, xfmname, recache=False, height=1024, from scipy.interpolate import interp1d # Make the flatmaps - ims, extents = make_flatmap_image(data, subject, xfmname, recache=recache, height=height, sampler=sampler) + ims, extents = make_flatmap_image( + data, subject, xfmname, recache=recache, height=height, sampler=sampler + ) if vmin is None: vmin = np.nanmin(ims) if vmax is None: @@ -469,9 +633,9 @@ def make_movie(name, data, subject, xfmname, recache=False, height=1024, img = fig.axes[0].images[0] # Set up interpolation - times = np.arange(0, len(ims)*tr, tr) + times = np.arange(0, len(ims) * tr, tr) interp = interp1d(times, ims, kind=interp, axis=0, copy=False) - frames = np.linspace(0, times[-1], (len(times)-1)*tr*fps+1) + frames = np.linspace(0, times[-1], (len(times) - 1) * tr * fps + 1) try: path = tempfile.mkdtemp() @@ -481,7 +645,9 @@ def make_movie(name, data, subject, xfmname, recache=False, height=1024, fig.savefig(impath.format(frame), transparent=True, dpi=dpi) # avconv might not be relevant function for all operating systems. # Introduce operating system check here? - cmd = "avconv -i {path} -vcodec {vcodec} -r {fps} -b {br} {name}".format(path=impath, vcodec=vcodec, fps=fps, br=bitrate, name=name) + cmd = "avconv -i {path} -vcodec {vcodec} -r {fps} -b {br} {name}".format( + path=impath, vcodec=vcodec, fps=fps, br=bitrate, name=name + ) sp.call(shlex.split(cmd)) finally: shutil.rmtree(path) diff --git a/cortex/tests/test_contours.py b/cortex/tests/test_contours.py new file mode 100644 index 000000000..80bcd4ddb --- /dev/null +++ b/cortex/tests/test_contours.py @@ -0,0 +1,301 @@ +"""Tests for contour/border rendering of parcellation data. + +Tests cover: +- Python utility: get_contour_vertices() +- Quickflat: add_contours(), _detect_label_borders(), make_figure(with_contours=...) +- WebGL: shader contour uniforms and geometry attributes +""" + +import numpy as np +import pytest + +import cortex +from cortex.quickflat.composite import _detect_label_borders +from cortex.testing_utils import has_installed + +no_inkscape = not has_installed("inkscape") + +SUBJECT = "S1" + + +def _make_parcellation(subject=SUBJECT): + """Create parcellation vertex data from existing ROIs. + + Returns array of shape (n_vertices,) with integer labels per ROI + and 0 for vertices not in any ROI. + """ + import warnings + + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + roi_verts = cortex.get_roi_verts(subject) + n_verts = cortex.db.get_surf(subject, "fiducial", merge=True)[0].shape[0] + parcellation = np.zeros(n_verts, dtype=float) + for i, (name, verts) in enumerate(roi_verts.items(), start=1): + parcellation[np.asarray(verts, dtype=int)] = float(i) + return parcellation + + +# --- Tests for Python contour utility --- + + +class TestGetContourVertices: + def test_returns_border_vertices(self): + """get_contour_vertices should return True at vertices bordering + different label values.""" + parcellation = _make_parcellation() + border = cortex.utils.get_contour_vertices(parcellation, SUBJECT) + assert border.dtype == bool + assert border.shape == parcellation.shape + # There should be some border vertices (parcellation has multiple labels) + assert border.sum() > 0 + # Border vertices should be fewer than total labeled vertices + assert border.sum() < (parcellation > 0).sum() + + def test_uniform_data_has_no_borders(self): + """Uniform data (all same label) should have no border vertices.""" + n_verts = cortex.db.get_surf(SUBJECT, "fiducial", merge=True)[0].shape[0] + uniform = np.ones(n_verts, dtype=float) + border = cortex.utils.get_contour_vertices(uniform, SUBJECT) + assert border.sum() == 0 + + def test_border_vertices_are_adjacent_to_different_labels(self): + """Every border vertex should have at least one neighbor with a + different label.""" + parcellation = _make_parcellation() + border = cortex.utils.get_contour_vertices(parcellation, SUBJECT) + + _, polys = cortex.db.get_surf(SUBJECT, "fiducial", merge=True) + neighbors = cortex.utils._get_neighbors_dict(polys) + + # Check a sample of border vertices + border_verts = np.where(border)[0][:100] + for v in border_verts: + neighbor_labels = { + parcellation[n] for n in neighbors[v] if n < len(parcellation) + } + assert ( + len(neighbor_labels) > 1 or parcellation[v] not in neighbor_labels + ), f"Border vertex {v} has no neighbor with different label" + + +# --- Tests for _detect_label_borders (2D image helper) --- + + +class TestDetectLabelBorders: + def test_uniform_image_no_borders(self): + """Uniform label image should have no borders.""" + img = np.ones((50, 50)) + border = _detect_label_borders(img) + assert border.sum() == 0 + + def test_two_regions_has_border(self): + """Image split into two regions should have a border between them.""" + img = np.ones((50, 50)) + img[:, 25:] = 2.0 + border = _detect_label_borders(img) + # Border should be at column 24 and 25 (the boundary pixels) + assert border.sum() > 0 + # Border pixels should be in the middle columns + border_cols = np.where(border.any(axis=0))[0] + assert 24 in border_cols or 25 in border_cols + + def test_nan_pixels_are_not_borders(self): + """NaN pixels (outside brain mask) should not be marked as borders.""" + img = np.full((50, 50), np.nan) + img[10:40, 10:40] = 1.0 + img[10:40, 25:40] = 2.0 + border = _detect_label_borders(img) + # Should have borders between label 1 and 2, but not at NaN edges + # (NaN-to-value transitions should not count as borders since + # we're interested in parcel-to-parcel boundaries, not brain mask edges) + nan_mask = np.isnan(img) + assert not border[nan_mask].any() + + def test_3d_image_uses_first_channel(self): + """For RGBA images, borders should be detected on first channel.""" + img = np.ones((50, 50, 4)) + img[:, 25:, 0] = 2.0 + border = _detect_label_borders(img) + assert border.sum() > 0 + + +# --- Tests for quickflat contour rendering --- + + +class TestQuickflatContours: + @pytest.mark.skipif(no_inkscape, reason="Inkscape required") + def test_add_contours_returns_image(self): + """add_contours() should return a matplotlib AxesImage.""" + from matplotlib import pyplot as plt + from cortex.quickflat.composite import add_contours + + parcellation = _make_parcellation() + parc_vertex = cortex.Vertex(parcellation, SUBJECT) + + fig, ax = plt.subplots() + img = add_contours(ax, parc_vertex, height=256) + assert img is not None + plt.close(fig) + + @pytest.mark.skipif(no_inkscape, reason="Inkscape required") + def test_add_contours_border_pixels_are_opaque(self): + """Contour overlay should have opaque pixels only at label borders.""" + from cortex.quickflat.composite import add_contours + from matplotlib import pyplot as plt + + parcellation = _make_parcellation() + parc_vertex = cortex.Vertex(parcellation, SUBJECT) + + fig, ax = plt.subplots() + img = add_contours(ax, parc_vertex, height=256) + rgba = img.get_array() + # Alpha channel should be > 0 only at border pixels + has_content = rgba[:, :, 3] > 0 if rgba.ndim == 3 else rgba > 0 + assert has_content.any(), "No contour pixels found" + plt.close(fig) + + @pytest.mark.skipif(no_inkscape, reason="Inkscape required") + def test_make_figure_with_contours(self): + """make_figure() with with_contours should produce a figure with + contour overlay.""" + parcellation = _make_parcellation() + parc_vertex = cortex.Vertex(parcellation, SUBJECT) + activation = cortex.Vertex( + np.random.randn(parcellation.shape[0]), SUBJECT, cmap="hot", vmin=-2, vmax=2 + ) + + fig = cortex.quickflat.make_figure( + activation, + with_contours=parc_vertex, + with_rois=False, + with_colorbar=False, + height=256, + ) + assert fig is not None + # Should have at least 2 images: data + contours + ax = fig.get_axes()[0] + images = ax.get_images() + assert len(images) >= 2 + + @pytest.mark.skipif(no_inkscape, reason="Inkscape required") + def test_make_figure_volume_with_vertex_contours(self): + """make_figure() should work with Volume data + Vertex contour overlay.""" + parcellation = _make_parcellation() + parc_vertex = cortex.Vertex(parcellation, SUBJECT) + volume = cortex.Volume.random(subject=SUBJECT, xfmname="fullhead") + + fig = cortex.quickflat.make_figure( + volume, + with_contours=parc_vertex, + with_rois=False, + with_colorbar=False, + height=256, + ) + assert fig is not None + ax = fig.get_axes()[0] + images = ax.get_images() + assert len(images) >= 2 + + @pytest.mark.skipif(no_inkscape, reason="Inkscape required") + def test_contour_linewidth(self): + """Thicker linewidth should produce more border pixels.""" + from cortex.quickflat.composite import add_contours + from matplotlib import pyplot as plt + + parcellation = _make_parcellation() + parc_vertex = cortex.Vertex(parcellation, SUBJECT) + + fig1, ax1 = plt.subplots() + img1 = add_contours(ax1, parc_vertex, height=256, linewidth=1) + rgba1 = img1.get_array() + n_pixels_1 = ( + (rgba1[:, :, 3] > 0).sum() if rgba1.ndim == 3 else (rgba1 > 0).sum() + ) + + fig2, ax2 = plt.subplots() + img2 = add_contours(ax2, parc_vertex, height=256, linewidth=3) + rgba2 = img2.get_array() + n_pixels_3 = ( + (rgba2[:, :, 3] > 0).sum() if rgba2.ndim == 3 else (rgba2 > 0).sum() + ) + + assert ( + n_pixels_3 > n_pixels_1 + ), f"linewidth=3 ({n_pixels_3}) should have more pixels than linewidth=1 ({n_pixels_1})" + plt.close("all") + + @pytest.mark.skipif(no_inkscape, reason="Inkscape required") + def test_contour_linecolor(self): + """Custom linecolor should appear in the contour image.""" + from cortex.quickflat.composite import add_contours + from matplotlib import pyplot as plt + + parcellation = _make_parcellation() + parc_vertex = cortex.Vertex(parcellation, SUBJECT) + + red = (1.0, 0.0, 0.0, 1.0) + fig, ax = plt.subplots() + img = add_contours(ax, parc_vertex, height=256, linecolor=red) + rgba = img.get_array() + if rgba.ndim == 3: + border_mask = rgba[:, :, 3] > 0 + # Red channel should be 1.0 at border pixels + np.testing.assert_allclose(rgba[border_mask, 0], 1.0) + # Green and blue should be 0 + np.testing.assert_allclose(rgba[border_mask, 1], 0.0) + np.testing.assert_allclose(rgba[border_mask, 2], 0.0) + plt.close(fig) + + +# --- Tests for WebGL contour support --- + + +class TestWebGLContours: + def test_shader_includes_contour_uniforms(self): + """Both surface_vertex and surface_pixel shaders should include + contour-related uniforms.""" + import os + + shader_path = os.path.join( + os.path.dirname(cortex.__file__), "webgl", "resources", "js", "shaderlib.js" + ) + with open(shader_path, "r") as f: + shader_code = f.read() + + assert "contourMode" in shader_code + assert "contourThreshold" in shader_code + assert "contourColor" in shader_code + assert "fwidth" in shader_code + assert "vDataValue" in shader_code + assert "vContourDataValue" in shader_code + assert "contourColormap" in shader_code + + def test_geometry_has_contour_attributes(self): + """mriview_surface.js should initialize contourData attributes.""" + import os + + surface_path = os.path.join( + os.path.dirname(cortex.__file__), + "webgl", + "resources", + "js", + "mriview_surface.js", + ) + with open(surface_path, "r") as f: + surface_code = f.read() + + assert "contourData0" in surface_code + assert "contourData1" in surface_code + + def test_viewer_has_contour_overlay_support(self): + """mriview.js should have contour overlay selection support.""" + import os + + viewer_path = os.path.join( + os.path.dirname(cortex.__file__), "webgl", "resources", "js", "mriview.js" + ) + with open(viewer_path, "r") as f: + viewer_code = f.read() + + assert "setContourOverlay" in viewer_code or "contour_overlay" in viewer_code diff --git a/cortex/utils.py b/cortex/utils.py index 8cfaf201d..40d9e1bba 100644 --- a/cortex/utils.py +++ b/cortex/utils.py @@ -1,5 +1,5 @@ -"""Contain utility functions -""" +"""Contain utility functions""" + import binascii import copy from importlib import import_module @@ -12,7 +12,19 @@ import urllib.request import warnings -from typing import Any, Callable, Generic, Optional, TypeVar, TYPE_CHECKING, Union, cast, overload, Literal +from typing import ( + Any, + Callable, + Generic, + Optional, + TypeVar, + TYPE_CHECKING, + Union, + cast, + overload, + Literal, +) + if sys.version_info < (3, 10): from typing_extensions import ParamSpec else: @@ -33,47 +45,67 @@ # register_cmap is deprecated in matplotlib > 3.7.0 and replaced by colormaps.register try: from matplotlib import colormaps as cm + def register_cmap(cmap): return cm.register(cmap) except ImportError: from matplotlib.cm import register_cmap -P = ParamSpec('P') -T = TypeVar('T') +P = ParamSpec("P") +T = TypeVar("T") + class DocLoader(Generic[P, T]): - def __init__(self, func, mod, package, actual_func: Optional[Callable[P, T]] = None): - self._load: Callable[[], Callable[P, T]] = lambda: getattr(import_module(mod, package), func) - self._actual_func = actual_func # stored only to resolve generic types during type checking + def __init__( + self, func, mod, package, actual_func: Optional[Callable[P, T]] = None + ): + self._load: Callable[[], Callable[P, T]] = lambda: getattr( + import_module(mod, package), func + ) + self._actual_func = ( + actual_func # stored only to resolve generic types during type checking + ) def __call__(self, *args: P.args, **kwargs: P.kwargs) -> T: return self._load()(*args, **kwargs) @overload - def __getattribute__(self, name: Literal['_load']) -> Callable[P, T]: ... + def __getattribute__(self, name: Literal["_load"]) -> Callable[P, T]: ... @overload def __getattribute__(self, name: str) -> Any: ... - def __getattribute__(self, name: Union[Literal['_load'], str]) -> Union[Any, Callable[P, T]]: + def __getattribute__( + self, name: Union[Literal["_load"], str] + ) -> Union[Any, Callable[P, T]]: if name != "_load": return getattr(self._load(), name) else: return cast(Callable[P, T], object.__getattribute__(self, name)) + if TYPE_CHECKING: from cortex.mapper import get_mapper as _get_mapper else: _get_mapper = None get_mapper = DocLoader("get_mapper", ".mapper", "cortex", actual_func=_get_mapper) + def get_roipack(*args, **kwargs): - warnings.warn('Please use db.get_overlay instead', DeprecationWarning) + warnings.warn("Please use db.get_overlay instead", DeprecationWarning) return db.get_overlay(*args, **kwargs) -def get_ctmpack(subject, types=("inflated",), method="raw", level=0, recache=False, - decimate=False, external_svg=None, - overlays_available=None): + +def get_ctmpack( + subject, + types=("inflated",), + method="raw", + level=0, + recache=False, + decimate=False, + external_svg=None, + overlays_available=None, +): """Creates ctm file for the specified input arguments. This is a cached file that specifies (1) the surfaces between which @@ -108,12 +140,10 @@ def get_ctmpack(subject, types=("inflated",), method="raw", level=0, recache=Fal ------- ctmfile : """ - lvlstr = ("%dd" if decimate else "%d")%level + lvlstr = ("%dd" if decimate else "%d") % level # Generates different cache files for each combination of disp_layers - ctmcache = "%s_[{types}]_{method}_{level}_v3.json"%subject - ctmcache = ctmcache.format(types=','.join(types), - method=method, - level=lvlstr) + ctmcache = "%s_[{types}]_{method}_{level}_v3.json" % subject + ctmcache = ctmcache.format(types=",".join(types), method=method, level=lvlstr) ctmfile = os.path.join(db.get_cache(subject), ctmcache) if os.path.exists(ctmfile) and not recache: @@ -121,20 +151,23 @@ def get_ctmpack(subject, types=("inflated",), method="raw", level=0, recache=Fal print("Generating new ctm file...") from . import brainctm - ptmap = brainctm.make_pack(ctmfile, - subject, - types=types, - method=method, - level=level, - decimate=decimate, - external_svg=external_svg, - overlays_available=overlays_available) + + ptmap = brainctm.make_pack( + ctmfile, + subject, + types=types, + method=method, + level=level, + decimate=decimate, + external_svg=external_svg, + overlays_available=overlays_available, + ) return ctmfile def get_ctmmap(subject, **kwargs): """Return a mapping from the vertices in the CTM surface to the vertices - in the freesurfer surface. + in the freesurfer surface. The mapping is a numpy array, such that `ctm2fs_left[i] = j` means that the i-th vertex in the CTM surface corresponds to the j-th vertex in the freesurfer surface. @@ -162,8 +195,9 @@ def get_ctmmap(subject, **kwargs): from scipy.spatial import cKDTree from . import brainctm + jsfile = get_ctmpack(subject, **kwargs) - ctmfile = os.path.splitext(jsfile)[0]+".ctm" + ctmfile = os.path.splitext(jsfile)[0] + ".ctm" # Load freesurfer surfaces try: @@ -209,6 +243,7 @@ def get_ctm2webgl_map(subject, **kwargs): maximum length of 65535. """ from . import brainctm + # Load CTM surfaces jsonfile = get_ctmpack(subject, **kwargs) ctmfile = os.path.splitext(jsonfile)[0] + ".ctm" @@ -290,7 +325,7 @@ def get_fs2webgl_map(subject, **kwargs): return fs2webgl_left, fs2webgl_right -def get_cortical_mask(subject, xfmname, type='nearest'): +def get_cortical_mask(subject, xfmname, type="nearest"): """Gets the cortical mask for a particular transform Parameters @@ -301,12 +336,12 @@ def get_cortical_mask(subject, xfmname, type='nearest'): Transform name type : str Mask type, one of {"cortical", "thin", "thick", "nearest", "line_nearest"}. - - 'cortical' includes voxels contained within the cortical ribbon, - between the freesurfer-estimated white matter and pial surfaces. - - 'thin' includes voxels that are < 2mm away from the fiducial surface. + - 'cortical' includes voxels contained within the cortical ribbon, + between the freesurfer-estimated white matter and pial surfaces. + - 'thin' includes voxels that are < 2mm away from the fiducial surface. - 'thick' includes voxels that are < 8mm away from the fiducial surface. - 'nearest' includes only the voxels overlapping the fiducial surface. - - 'line_nearest' includes all voxels that have any part within the cortical + - 'line_nearest' includes all voxels that have any part within the cortical ribbon. Returns @@ -316,13 +351,13 @@ def get_cortical_mask(subject, xfmname, type='nearest'): Notes ----- - "nearest" is a conservative "cortical" mask, while "line_nearest" is a liberal + "nearest" is a conservative "cortical" mask, while "line_nearest" is a liberal "cortical" mask. """ - if type == 'cortical': + if type == "cortical": ppts, polys = db.get_surf(subject, "pia", merge=True, nudge=False) wpts, polys = db.get_surf(subject, "wm", merge=True, nudge=False) - thickness = np.sqrt(((ppts - wpts)**2).sum(1)) + thickness = np.sqrt(((ppts - wpts) ** 2).sum(1)) dist, idx = get_vox_dist(subject, xfmname) cortex = np.zeros(dist.shape, dtype=bool) @@ -331,9 +366,9 @@ def get_cortical_mask(subject, xfmname, type='nearest'): mask = idx == vert cortex[mask] = dist[mask] <= thickness[vert] if i % 100 == 0: - print("%0.3f%%"%(i/float(len(verts)) * 100)) + print("%0.3f%%" % (i / float(len(verts)) * 100)) return cortex - elif type in ('thick', 'thin'): + elif type in ("thick", "thin"): dist, idx = get_vox_dist(subject, xfmname) return dist < dict(thick=8, thin=2)[type] else: @@ -379,8 +414,8 @@ def get_vox_dist(subject, xfmname, surface="fiducial", max_dist=np.inf): return dist.T, argdist.T -def get_hemi_masks(subject, xfmname, type='nearest'): - '''Returns a binary mask of the left and right hemisphere +def get_hemi_masks(subject, xfmname, type="nearest"): + """Returns a binary mask of the left and right hemisphere surface voxels for the given subject. Parameters @@ -394,11 +429,13 @@ def get_hemi_masks(subject, xfmname, type='nearest'): Returns ------- - ''' + """ return get_mapper(subject, xfmname, type=type).hemimasks -def add_roi(data, name="new_roi", open_inkscape=True, add_path=True, - overlay_file=None, **kwargs): + +def add_roi( + data, name="new_roi", open_inkscape=True, add_path=True, overlay_file=None, **kwargs +): """Add new flatmap image to the ROI file for a subject. (The subject is specified in creation of the data object) @@ -440,14 +477,16 @@ def add_roi(data, name="new_roi", open_inkscape=True, add_path=True, svg = db.get_overlay(dv.subject, overlay_file=overlay_file) fp = io.BytesIO() - quickflat.make_png(fp, dv, height=1024, with_rois=False, with_labels=False, **kwargs) + quickflat.make_png( + fp, dv, height=1024, with_rois=False, with_labels=False, **kwargs + ) fp.seek(0) - svg.rois.add_shape(name, binascii.b2a_base64(fp.read()).decode('utf-8'), add_path) + svg.rois.add_shape(name, binascii.b2a_base64(fp.read()).decode("utf-8"), add_path) if open_inkscape: - inkscape_cmd = config.get('dependency_paths', 'inkscape') - if LooseVersion(INKSCAPE_VERSION) < LooseVersion('1.0'): - cmd = [inkscape_cmd, '-f', svg.svgfile] + inkscape_cmd = config.get("dependency_paths", "inkscape") + if LooseVersion(INKSCAPE_VERSION) < LooseVersion("1.0"): + cmd = [inkscape_cmd, "-f", svg.svgfile] else: cmd = [inkscape_cmd, svg.svgfile] return sp.call(cmd) @@ -463,6 +502,41 @@ def _get_neighbors_dict(polys): return neighbors_dict +def get_contour_vertices(data, subject, surface="fiducial"): + """Find vertices at borders of parcellation labels. + + A vertex is a border vertex if any of its mesh neighbors has a different + label value. This is useful for drawing contour lines around parcellation + regions on the cortical surface. + + Parameters + ---------- + data : array_like, shape (n_vertices,) + Label values per vertex (e.g., parcellation integers). + subject : str + Subject name in the pycortex database. + surface : str + Surface type for adjacency computation (default: 'fiducial'). + + Returns + ------- + border_mask : ndarray, shape (n_vertices,), dtype bool + True at border vertices, False elsewhere. + """ + _, polys = db.get_surf(subject, surface, merge=True) + neighbors = _get_neighbors_dict(polys) + data = np.asarray(data) + border = np.zeros(len(data), dtype=bool) + for v, neighs in neighbors.items(): + if v >= len(data): + continue + for n in neighs: + if n < len(data) and data[v] != data[n]: + border[v] = True + break + return border + + def get_roi_verts(subject, roi=None, mask=False, overlay_file=None): """Return vertices for the given ROIs, or all ROIs if none are given. @@ -509,8 +583,8 @@ def get_roi_verts(subject, roi=None, mask=False, overlay_file=None): for name in roi: roi_idx = np.intersect1d(svg.rois.get_mask(name), goodpts) - # Now we want to include also the vertices that were removed from the flat - # surface that is, for every vertex in roi_idx we want to add the pts that are + # Now we want to include also the vertices that were removed from the flat + # surface that is, for every vertex in roi_idx we want to add the pts that are # not in goodpts but that are in pts_full # to do that, we need to find the neighboring indices from polys_full extra_idx = set() @@ -568,7 +642,7 @@ def get_roi_surf(subject, surf_type, roi, overlay_file=None): return pts[vert_idx], np.array(reindexed_polys) -def get_roi_mask(subject, xfmname, roi=None, projection='nearest'): +def get_roi_mask(subject, xfmname, roi=None, projection="nearest"): """Return a mask for the given ROI(s) Deprecated - use get_roi_masks() @@ -588,15 +662,15 @@ def get_roi_mask(subject, xfmname, roi=None, projection='nearest'): output : dict Dict of ROIs and their masks """ - warnings.warn('Deprecated! Use get_roi_masks') + warnings.warn("Deprecated! Use get_roi_masks") mapper = get_mapper(subject, xfmname, type=projection) rois = get_roi_verts(subject, roi=roi, mask=True) output = dict() for name, verts in list(rois.items()): # This is broken; unclear when/if backward mappers ever worked this way. - #left, right = mapper.backwards(vert_mask) - #output[name] = left + right + # left, right = mapper.backwards(vert_mask) + # output[name] = left + right output[name] = mapper.backwards(verts.astype(float)) # Threshold? return output @@ -644,6 +718,7 @@ def get_aseg_mask(subject, aseg_name, xfmname=None, order=1, threshold=None, **k """ from .freesurfer import fs_aseg_dict + aseg = db.get_anat(subject, type="aseg").get_fdata().T if not isinstance(aseg_name, (list, tuple)): @@ -652,24 +727,36 @@ def get_aseg_mask(subject, aseg_name, xfmname=None, order=1, threshold=None, **k mask = np.zeros(aseg.shape) for name in aseg_name: if name in fs_aseg_dict: - tmp = aseg==fs_aseg_dict[name] + tmp = aseg == fs_aseg_dict[name] else: # Combine all masks containing `name` (e.g. all masks with 'cerebellum' in the name) keys = [k for k in fs_aseg_dict.keys() if name.lower() in k.lower()] if len(keys) == 0: - raise ValueError('Unknown aseg_name!') - tmp = np.any(np.array([aseg==fs_aseg_dict[k] for k in keys]), axis=0) + raise ValueError("Unknown aseg_name!") + tmp = np.any(np.array([aseg == fs_aseg_dict[k] for k in keys]), axis=0) mask = np.logical_or(mask, tmp) if xfmname is not None: - mask = anat2epispace(mask.astype(float), subject, xfmname, order=order, **kwargs) + mask = anat2epispace( + mask.astype(float), subject, xfmname, order=order, **kwargs + ) if threshold is not None: mask = mask > threshold return mask -def get_roi_masks(subject, xfmname, roi_list=None, gm_sampler='cortical', split_lr=False, - allow_overlap=False, fail_for_missing_rois=True, exclude_empty_rois=False, - threshold=None, return_dict=True, overlay_file=None): +def get_roi_masks( + subject, + xfmname, + roi_list=None, + gm_sampler="cortical", + split_lr=False, + allow_overlap=False, + fail_for_missing_rois=True, + exclude_empty_rois=False, + threshold=None, + return_dict=True, + overlay_file=None, +): """Return a dictionary of roi masks This function returns a single 3D array with a separate numerical index for each ROI, @@ -748,13 +835,17 @@ def get_roi_masks(subject, xfmname, roi_list=None, gm_sampler='cortical', split_ 'thin' as your `gm_sampler`. """ # Convert mapper names to pycortex sampler types - mapper_dict = {'cortical-conservative':'nearest', - 'cortical-liberal':'line_nearest'} + mapper_dict = { + "cortical-conservative": "nearest", + "cortical-liberal": "line_nearest", + } # Method use_mapper = gm_sampler in mapper_dict - use_cortex_mask = (gm_sampler in ('cortical', 'thick', 'thin')) or not isinstance(gm_sampler, str) + use_cortex_mask = (gm_sampler in ("cortical", "thick", "thin")) or not isinstance( + gm_sampler, str + ) if not (use_mapper or use_cortex_mask): - raise ValueError('Unknown gray matter sampler (gm_sampler)!') + raise ValueError("Unknown gray matter sampler (gm_sampler)!") # Initialize roi_voxels = {} pct_coverage = {} @@ -773,18 +864,28 @@ def get_roi_masks(subject, xfmname, roi_list=None, gm_sampler='cortical', split_ roi_verts = get_roi_verts(subject, mask=use_mapper, overlay_file=overlay_file) roi_list = list(roi_verts.keys()) else: - tmp_list = [r for r in roi_list if not r=='Cortex'] + tmp_list = [r for r in roi_list if not r == "Cortex"] try: - roi_verts = get_roi_verts(subject, roi=tmp_list, mask=use_mapper, overlay_file=overlay_file) + roi_verts = get_roi_verts( + subject, roi=tmp_list, mask=use_mapper, overlay_file=overlay_file + ) except KeyError as key: if fail_for_missing_rois: - raise KeyError("Requested ROI {} not found in overlays.svg!".format(key)) + raise KeyError( + "Requested ROI {} not found in overlays.svg!".format(key) + ) else: - roi_verts = get_roi_verts(subject, roi=None, mask=use_mapper, overlay_file=overlay_file) - missing = [r for r in roi_list if not r in roi_verts.keys()+['Cortex']] - roi_verts = dict((roi, verts) for roi, verts in roi_verts.items() if roi in roi_list) - roi_list = list(set(roi_list)-set(missing)) - print('Requested ROI(s) {} not found in overlays.svg!'.format(missing)) + roi_verts = get_roi_verts( + subject, roi=None, mask=use_mapper, overlay_file=overlay_file + ) + missing = [ + r for r in roi_list if not r in roi_verts.keys() + ["Cortex"] + ] + roi_verts = dict( + (roi, verts) for roi, verts in roi_verts.items() if roi in roi_list + ) + roi_list = list(set(roi_list) - set(missing)) + print("Requested ROI(s) {} not found in overlays.svg!".format(missing)) # Get (a) indices for nearest vertex to each voxel # and (b) distance from each voxel to nearest vertex in fiducial surface if (use_cortex_mask or split_lr) or (not return_dict): @@ -799,7 +900,7 @@ def get_roi_masks(subject, xfmname, roi_list=None, gm_sampler='cortical', split_ # Loop over ROIs to map vertices to volume, using mapper or cortex mask + vertex indices for roi in roi_list: if roi not in roi_verts: - if not roi=='Cortex': + if not roi == "Cortex": print("ROI {} not found...".format(roi)) continue if use_mapper: @@ -808,10 +909,14 @@ def get_roi_masks(subject, xfmname, roi_list=None, gm_sampler='cortical', split_ if threshold is not None: roi_voxels[roi] = roi_voxels[roi] > threshold # Check for partial / empty rois: - vert_in_scan = np.hstack([np.array((m>0).sum(1)).flatten() for m in mapper.masks]) + vert_in_scan = np.hstack( + [np.array((m > 0).sum(1)).flatten() for m in mapper.masks] + ) vert_in_scan = vert_in_scan[roi_verts[roi]] elif use_cortex_mask: - vox_in_roi = np.in1d(vox_idx.flatten(), roi_verts[roi]).reshape(vox_idx.shape) + vox_in_roi = np.in1d(vox_idx.flatten(), roi_verts[roi]).reshape( + vox_idx.shape + ) roi_voxels[roi] = vox_in_roi & cortex_mask # This is not accurate... because vox_idx only contains the indices of the *nearest* # vertex to each voxel, it excludes many vertices. I can't think of a way to compute @@ -820,58 +925,69 @@ def get_roi_masks(subject, xfmname, roi_list=None, gm_sampler='cortical', split_ # Compute ROI coverage pct_coverage[roi] = vert_in_scan.mean() * 100 if use_mapper: - print("Found %0.2f%% of %s"%(pct_coverage[roi], roi)) + print("Found %0.2f%% of %s" % (pct_coverage[roi], roi)) # Create cortex mask all_mask = np.array(list(roi_voxels.values())).sum(0) - if 'Cortex' in roi_list: + if "Cortex" in roi_list: if use_mapper: # cortex_mask isn't defined / exactly definable if you're using a mapper - print("Cortex roi not included b/c currently not compatible with your selection for gm_sampler") - _ = roi_list.pop(roi_list.index('Cortex')) + print( + "Cortex roi not included b/c currently not compatible with your selection for gm_sampler" + ) + _ = roi_list.pop(roi_list.index("Cortex")) else: - roi_voxels['Cortex'] = (all_mask==0) & cortex_mask + roi_voxels["Cortex"] = (all_mask == 0) & cortex_mask # Optionally cull voxels assigned to > 1 ROI due to partly overlapping ROI splines # in inkscape overlays.svg file: if not allow_overlap: - print('Cutting {} overlapping voxels (should be < ~50)'.format(np.sum(all_mask > 1))) + print( + "Cutting {} overlapping voxels (should be < ~50)".format( + np.sum(all_mask > 1) + ) + ) for roi in roi_list: roi_voxels[roi][all_mask > 1] = False # Split left / right hemispheres if desired if split_lr: # Use the fiducial surface because we need to have all vertices - left_verts, _ = db.get_surf(subject, "fiducial", merge=False, nudge=True) + left_verts, _ = db.get_surf(subject, "fiducial", merge=False, nudge=True) left_mask = vox_idx < len(np.unique(left_verts[1])) right_mask = np.logical_not(left_mask) roi_voxels_lr = {} for roi in roi_list: # roi_voxels may contain float values if using a mapper, therefore we need # to manually set the voxels in the other hemisphere to False. Then we let - # numpy do the conversion False -> 0. - roi_voxels_lr[roi + '_L'] = copy.copy(roi_voxels[roi]) - roi_voxels_lr[roi + '_L'][right_mask] = False - roi_voxels_lr[roi + '_R'] = copy.copy(roi_voxels[roi]) - roi_voxels_lr[roi + '_R'][left_mask] = False + # numpy do the conversion False -> 0. + roi_voxels_lr[roi + "_L"] = copy.copy(roi_voxels[roi]) + roi_voxels_lr[roi + "_L"][right_mask] = False + roi_voxels_lr[roi + "_R"] = copy.copy(roi_voxels[roi]) + roi_voxels_lr[roi + "_R"][left_mask] = False output = roi_voxels_lr else: output = roi_voxels # Check percent coverage / optionally cull empty ROIs - for roi in set(roi_list)-set(['Cortex']): + for roi in set(roi_list) - set(["Cortex"]): if pct_coverage[roi] < 100: # if not np.any(mask) : reject ROI - if pct_coverage[roi]==0: - warnings.warn('ROI %s is entirely missing from your scan protocol!'%(roi)) + if pct_coverage[roi] == 0: + warnings.warn( + "ROI %s is entirely missing from your scan protocol!" % (roi) + ) if exclude_empty_rois: if split_lr: - _ = output.pop(roi+'_L') - _ = output.pop(roi+'_R') + _ = output.pop(roi + "_L") + _ = output.pop(roi + "_R") else: _ = output.pop(roi) else: # I think this is the only one for which this works correctly... - if gm_sampler=='cortical-conservative': - warnings.warn('ROI %s is only %0.2f%% contained in your scan protocol!'%(roi, pct_coverage[roi])) + if gm_sampler == "cortical-conservative": + warnings.warn( + "ROI %s is only %0.2f%% contained in your scan protocol!" + % (roi, pct_coverage[roi]) + ) # Support alternative outputs for backward compatibility if return_dict: @@ -885,6 +1001,7 @@ def get_roi_masks(subject, xfmname, roi_list=None, gm_sampler='cortical', split_ idx_vol[left_mask] *= -1 return idx_vol, idx_labels + def get_dropout(subject: str, xfmname: str, power: float = 20): """Create a dropout Volume showing where EPI signal is very low. @@ -909,13 +1026,15 @@ def get_dropout(subject: str, xfmname: str, power: float = 20): if rawdata.ndim > 3: rawdata = rawdata.mean(0) - rawdata[rawdata==0] = np.mean(rawdata[rawdata!=0]) + rawdata[rawdata == 0] = np.mean(rawdata[rawdata != 0]) normdata = (rawdata - rawdata.min()) / (rawdata.max() - rawdata.min()) normdata = (1 - normdata) ** power from .dataset import Volume + return Volume(normdata, subject, xfmname) + def make_movie(stim, outfile, fps=15, size="640x480"): """Makes an .ogv movie @@ -939,10 +1058,12 @@ def make_movie(stim, outfile, fps=15, size="640x480"): """ import shlex import subprocess as sp + cmd = "ffmpeg -r {fps} -i {infile} -b 4800k -g 30 -s {size} -vcodec libtheora {outfile}.ogv" fcmd = cmd.format(infile=stim, size=size, fps=fps, outfile=outfile) sp.call(shlex.split(fcmd)) + def vertex_to_voxel(subject): # Am I deprecated in favor of mappers??? Maybe? """ Parameters @@ -975,28 +1096,30 @@ def vertex_to_voxel(subject): # Am I deprecated in favor of mappers??? Maybe? def _set_edge_distance_graph_attribute(graph, pts, polys): - ''' + """ adds the attribute 'edge distance' to a graph - ''' + """ import networkx as nx l2_distance = lambda v1, v2: np.linalg.norm(pts[v1] - pts[v2]) - heuristic = l2_distance # A* heuristic + heuristic = l2_distance # A* heuristic - if not nx.get_edge_attributes(graph, 'distance'): # Add edge distances as an attribute to this graph if it isn't there + if not nx.get_edge_attributes( + graph, "distance" + ): # Add edge distances as an attribute to this graph if it isn't there edge_distances = dict() - for x,y,z in polys: - edge_distances[(x,y)] = heuristic(x,y) - edge_distances[(y,x)] = heuristic(y,x) - edge_distances[(y,z)] = heuristic(y,z) - edge_distances[(z,y)] = heuristic(z,y) - edge_distances[(x,z)] = heuristic(x,z) - edge_distances[(z,x)] = heuristic(z,x) - nx.set_edge_attributes(graph, edge_distances, name='distance') + for x, y, z in polys: + edge_distances[(x, y)] = heuristic(x, y) + edge_distances[(y, x)] = heuristic(y, x) + edge_distances[(y, z)] = heuristic(y, z) + edge_distances[(z, y)] = heuristic(z, y) + edge_distances[(x, z)] = heuristic(x, z) + edge_distances[(z, x)] = heuristic(z, x) + nx.set_edge_attributes(graph, edge_distances, name="distance") def get_shared_voxels(subject, xfmname, hemi="both", merge=True, use_astar=True): - '''Return voxels that are shared by multiple vertices, and for each such voxel, + """Return voxels that are shared by multiple vertices, and for each such voxel, also returns the mutually farthest pair of vertices mapping to the voxel Parameters ---------- @@ -1017,18 +1140,21 @@ def get_shared_voxels(subject, xfmname, hemi="both", merge=True, use_astar=True) vox_vert_array: np.array, array of dimensions # voxels X 3, columns being: (vox_idx, farthest_pair[0], farthest_pair[1]) - ''' + """ import networkx as nx from scipy.sparse import find as sparse_find - Lmask, Rmask = get_mapper(subject, xfmname).masks # Get masks for left and right hemisphere - if hemi == 'both': - hemispheres = ['lh', 'rh'] + + Lmask, Rmask = get_mapper( + subject, xfmname + ).masks # Get masks for left and right hemisphere + if hemi == "both": + hemispheres = ["lh", "rh"] else: hemispheres = [hemi] out = [] for hem in hemispheres: - if hem == 'lh': + if hem == "lh": mask = Lmask else: mask = Rmask @@ -1036,8 +1162,10 @@ def get_shared_voxels(subject, xfmname, hemi="both", merge=True, use_astar=True) all_voxels = mask.tolil().transpose().rows # Map from voxels to verts vert_to_vox_map = dict(zip(*(sparse_find(mask)[:2]))) # From verts to vox - pts_fid, polys_fid = db.get_surf(subject, 'fiducial', hem) # Get the fiducial surface - surf = Surface(pts_fid, polys_fid) #Get the fiducial surface + pts_fid, polys_fid = db.get_surf( + subject, "fiducial", hem + ) # Get the fiducial surface + surf = Surface(pts_fid, polys_fid) # Get the fiducial surface graph = surf.graph _set_edge_distance_graph_attribute(graph, pts_fid, polys_fid) @@ -1046,32 +1174,44 @@ def get_shared_voxels(subject, xfmname, hemi="both", merge=True, use_astar=True) heuristic = l2_distance # A* heuristic if use_astar: - shortest_path = lambda a, b: nx.astar_path(graph, a, b, heuristic=heuristic, weight='distance') # Find approximate shortest paths using A* search + shortest_path = lambda a, b: nx.astar_path( + graph, a, b, heuristic=heuristic, weight="distance" + ) # Find approximate shortest paths using A* search else: - shortest_path = surf.geodesic_path # Find shortest paths using geodesic distances + shortest_path = ( + surf.geodesic_path + ) # Find shortest paths using geodesic distances vox_vert_list = [] for vox_idx, vox in enumerate(all_voxels): if len(vox) > 1: # If the voxel maps to multiple vertices vox = np.array(vox).astype(int) - for v1 in range(vox.size-1): + for v1 in range(vox.size - 1): vert1 = vox[v1] if vert1 in vert_to_vox_map: # If the vertex is a valid vertex - for v2 in range(v1+1, vox.size): + for v2 in range(v1 + 1, vox.size): vert2 = vox[v2] - if vert2 in vert_to_vox_map: # If the vertex is a valid vertex + if ( + vert2 in vert_to_vox_map + ): # If the vertex is a valid vertex path = shortest_path(vert1, vert2) # Test whether any vertex in path goes out of the voxel - stays_in_voxel = all([(v in vert_to_vox_map) and (vert_to_vox_map[v] == vox_idx) for v in path]) + stays_in_voxel = all( + [ + (v in vert_to_vox_map) + and (vert_to_vox_map[v] == vox_idx) + for v in path + ] + ) if not stays_in_voxel: vox_vert_list.append([vox_idx, vert1, vert2]) - tmp = np.array(vox_vert_list) + tmp = np.array(vox_vert_list) # Add offset for right hem voxels - if hem=='rh': + if hem == "rh": tmp[:, 1:3] += Lmask.shape[0] out.append(tmp) - if hemi in ('lh', 'rh'): + if hemi in ("lh", "rh"): return out[0] else: if merge: @@ -1096,13 +1236,18 @@ def load_sparse_array(fname, varname): conventions, so cannot be used to load arbitrary sparse arrays. """ import scipy.sparse + with h5py.File(fname) as hf: - data = (hf['%s_data'%varname], hf['%s_indices'%varname], hf['%s_indptr'%varname]) - sparsemat = scipy.sparse.csr_matrix(data, shape=hf['%s_shape'%varname]) + data = ( + hf["%s_data" % varname], + hf["%s_indices" % varname], + hf["%s_indptr" % varname], + ) + sparsemat = scipy.sparse.csr_matrix(data, shape=hf["%s_shape" % varname]) return sparsemat -def save_sparse_array(fname, data, varname, mode='a'): +def save_sparse_array(fname, data, varname, mode="a"): """Save a numpy sparse array to an hdf file Results in relatively smaller file size than numpy.savez @@ -1119,19 +1264,20 @@ def save_sparse_array(fname, data, varname, mode='a'): write / append mode set, one of ['w','a'] (passed to h5py.File()) """ import scipy.sparse + if not isinstance(data, scipy.sparse.csr.csr_matrix): data_ = scipy.sparse.csr_matrix(data) else: data_ = data with h5py.File(fname, mode=mode) as hf: # Save indices - hf.create_dataset(varname + '_indices', data=data_.indices, compression='gzip') + hf.create_dataset(varname + "_indices", data=data_.indices, compression="gzip") # Save data - hf.create_dataset(varname + '_data', data=data_.data, compression='gzip') + hf.create_dataset(varname + "_data", data=data_.data, compression="gzip") # Save indptr - hf.create_dataset(varname + '_indptr', data=data_.indptr, compression='gzip') + hf.create_dataset(varname + "_indptr", data=data_.indptr, compression="gzip") # Save shape - hf.create_dataset(varname + '_shape', data=data_.shape, compression='gzip') + hf.create_dataset(varname + "_shape", data=data_.shape, compression="gzip") def get_cmap(name): @@ -1149,10 +1295,11 @@ def get_cmap(name): """ import matplotlib.pyplot as plt from matplotlib import colors + # unknown colormap, test whether it's in pycortex colormaps - cmapdir = config.get('webgl', 'colormaps') + cmapdir = config.get("webgl", "colormaps") colormaps = os.listdir(cmapdir) - colormaps = sorted([c for c in colormaps if '.png' in c]) + colormaps = sorted([c for c in colormaps if ".png" in c]) colormaps = dict((c[:-4], os.path.join(cmapdir, c)) for c in colormaps) if name in colormaps: I = plt.imread(colormaps[name]) @@ -1165,15 +1312,16 @@ def get_cmap(name): try: cmap = plt.cm.get_cmap(name) except: - raise Exception('Unkown color map!') + raise Exception("Unknown color map!") return cmap + def add_cmap(cmap, name, cmapdir=None): """Add a colormap to pycortex. This stores a matplotlib colormap in the pycortex filestore, such that it can - be used in the webgl viewer in pycortex. See - https://matplotlib.org/stable/users/explain/colors/colormap-manipulation.html + be used in the webgl viewer in pycortex. See + https://matplotlib.org/stable/users/explain/colors/colormap-manipulation.html for more information about how to generate colormaps in matplotlib. Parameters @@ -1182,8 +1330,8 @@ def add_cmap(cmap, name, cmapdir=None): Color map to be saved name : str Name for colormap, e.g. 'jet', 'blue_to_yellow', etc. The name will be used - to generate a filename for the colormap stored in the pycortex store, - so avoid illegal characters for a filename. This name will also be used to + to generate a filename for the colormap stored in the pycortex store, + so avoid illegal characters for a filename. This name will also be used to specify this colormap in future calls to `cortex.quickflat.make_figure()` or `cortex.webgl.show()`. """ @@ -1199,8 +1347,9 @@ def add_cmap(cmap, name, cmapdir=None): plt.imsave(os.path.join(cmapdir, name), cmap_im, format="png") -def download_subject(subject_id='fsaverage', url=None, pycortex_store=None, - download_again=False): +def download_subject( + subject_id="fsaverage", url=None, pycortex_store=None, download_again=False +): """Download subjects to pycortex store Parameters @@ -1224,15 +1373,16 @@ def download_subject(subject_id='fsaverage', url=None, pycortex_store=None, warnings.warn( "{} is already present in the database. " "Set download_again to True if you wish to download " - "the subject again.".format(subject_id)) + "the subject again.".format(subject_id) + ) return # Map codes to URLs; more coming eventually id_to_url = dict( - fsaverage='https://ndownloader.figshare.com/files/17827577?private_link=4871247dce31e188e758', + fsaverage="https://ndownloader.figshare.com/files/17827577?private_link=4871247dce31e188e758", ) if url is None: if subject_id not in id_to_url: - raise ValueError('Unknown subject_id!') + raise ValueError("Unknown subject_id!") url = id_to_url[subject_id] # Setup pycortex store location if pycortex_store is None: @@ -1242,12 +1392,11 @@ def download_subject(subject_id='fsaverage', url=None, pycortex_store=None, # Download to temp dir print("Downloading from: {}".format(url)) with tempfile.TemporaryDirectory() as tmp_dir: - print('Downloading subject {} to {}'.format(subject_id, tmp_dir)) + print("Downloading subject {} to {}".format(subject_id, tmp_dir)) fnout, _ = urllib.request.urlretrieve( - url, - os.path.join(tmp_dir, f"{subject_id}.tar.gz") + url, os.path.join(tmp_dir, f"{subject_id}.tar.gz") ) - print(f'Done downloading to {fnout}') + print(f"Done downloading to {fnout}") # Un-tar to pycortex store with tarfile.open(fnout, "r:gz") as tar: print("Extracting subject {} to {}".format(subject_id, pycortex_store)) @@ -1259,51 +1408,56 @@ def download_subject(subject_id='fsaverage', url=None, pycortex_store=None, def rotate_flatmap(surf_id, theta, plot=False): """Rotate flatmap to be less V-shaped - + Parameters ---------- surf_id : str pycortex surface identifier theta : scalar - angle in degrees to rotate flatmaps (rotation is clockwise + angle in degrees to rotate flatmaps (rotation is clockwise for right hemisphere and counter-clockwise for left) plot : bool Whether to make a coarse plot to visualize the changes """ # Lazy load of matplotlib import matplotlib.pyplot as plt - paths = db.get_paths(surf_id)['surfs']['flat'] + + paths = db.get_paths(surf_id)["surfs"]["flat"] theta = np.radians(theta) if plot: fig, axs = plt.subplots(2, 2) - for j, hem in enumerate(('lh','rh')): + for j, hem in enumerate(("lh", "rh")): this_file = paths[hem] pts, polys = formats.read_gii(this_file) # Rotate clockwise (- rotation) for RH, counter-clockwise (+ rotation) for LH - if hem == 'rh': - rtheta = - theta + if hem == "rh": + rtheta = -theta else: rtheta = copy.copy(theta) - rotation_mat = np.array([[np.cos(rtheta), -np.sin(rtheta)], [np.sin(rtheta), np.cos(rtheta)]]) + rotation_mat = np.array( + [[np.cos(rtheta), -np.sin(rtheta)], [np.sin(rtheta), np.cos(rtheta)]] + ) rotated = rotation_mat.dot(pts[:, :2].T).T pts_new = pts.copy() pts_new[:, :2] = rotated new_file, bkup_num = copy.copy(this_file), 0 while os.path.exists(new_file): - new_file = this_file.replace('.gii', '_rotbkup%02d.gii'%bkup_num) + new_file = this_file.replace(".gii", "_rotbkup%02d.gii" % bkup_num) bkup_num += 1 - print('Backing up file at %s...' % new_file) + print("Backing up file at %s..." % new_file) shutil.copy(this_file, new_file) formats.write_gii(this_file, pts_new, polys) - print('Overwriting %s...' % this_file) + print("Overwriting %s..." % this_file) if plot: - axs[0,j].plot(*pts[::100, :2].T, marker='r.') - axs[0,j].axis('equal') - axs[1,j].plot(*pts_new[::100, :2].T, marker='b.') - axs[1,j].axis('equal') + axs[0, j].plot(*pts[::100, :2].T, marker="r.") + axs[0, j].axis("equal") + axs[1, j].plot(*pts_new[::100, :2].T, marker="b.") + axs[1, j].axis("equal") # Remove and back up overlays file - overlay_file = db.get_paths(surf_id)['overlays'] - shutil.copy(overlay_file, overlay_file.replace('.svg', '_rotbkup%02d.svg'%bkup_num)) + overlay_file = db.get_paths(surf_id)["overlays"] + shutil.copy( + overlay_file, overlay_file.replace(".svg", "_rotbkup%02d.svg" % bkup_num) + ) os.unlink(overlay_file) # Regenerate file svg = db.get_overlay(surf_id) diff --git a/cortex/webgl/resources/js/mriview.js b/cortex/webgl/resources/js/mriview.js index ff536df6c..1b6ec52cc 100644 --- a/cortex/webgl/resources/js/mriview.js +++ b/cortex/webgl/resources/js/mriview.js @@ -222,6 +222,37 @@ var mriview = (function(module) { } this.setData(data[0].name); + + // Populate the contours folder: overlay first, then mode and threshold (only once) + if (!this._contourUIAdded) { + var contourOptions = {"none": "none"}; + for (var dname in this.dataviews) { + if (this.dataviews[dname].vertex) { + contourOptions[dname] = dname; + } + } + this._contourOverlayName = "none"; + var viewer = this; + for (var i = 0; i < this.surfs.length; i++) { + (function(surf) { + surf.surf.loaded.done(function() { + var contoursFolder = surf.surf.ui.contours; + // Overlay dropdown first (if multiple vertex datasets) + if (Object.keys(contourOptions).length > 1) { + contoursFolder.add({ + overlay: {action:[viewer, "setContourOverlay", contourOptions]}, + }); + } + // Then mode and threshold + contoursFolder.add({ + mode: {action:[surf.surf, "setContourMode", {off:0, "contours only":1, "contours + fill":2, "colored contours":3, "colored + fill":4}]}, + threshold: {action:[surf.surf.uniforms.contourThreshold, "value", 0.001, 0.5]}, + }); + }); + })(this.surfs[i]); + } + this._contourUIAdded = true; + } }; module.Viewer.prototype.setData = function(name) { @@ -614,6 +645,78 @@ var mriview = (function(module) { this.playpause(); this.setData([datasets[(i+dir).mod(datasets.length)]]); }; + module.Viewer.prototype.setContourOverlay = function(name) { + if (name === undefined) + return this._contourOverlayName || "none"; + + if (name === "none" || name === null || name === 0) { + this._contourOverlayName = "none"; + this.contourOverlay = null; + for (var i = 0; i < this.surfs.length; i++) { + this.surfs[i].surf.uniforms.contourOverlay.value = 0; + } + this.schedule(); + return; + } + + var overlayView = this.dataviews[name]; + if (!overlayView) { + console.warn("setContourOverlay: dataset '" + name + "' not found. Available: " + + Object.keys(this.dataviews).join(", ")); + return; + } + if (!overlayView.vertex) { + console.warn("setContourOverlay: dataset '" + name + "' is not vertex data. " + + "Contour overlays require vertex (surface) data."); + return; + } + + this._contourOverlayName = name; + this.contourOverlay = name; + var overlayData = overlayView.data[0]; + + var viewer = this; + var applyOverlay = function() { + // Use frame 0 for contour overlay data. Parcellation overlays + // are typically single-frame (static labels). Multi-frame + // contour overlays are not currently supported. + var fframe = 0; + var verts = overlayData.verts[fframe]; + var verts1 = overlayData.verts[(fframe+1) % overlayData.verts.length]; + for (var i = 0; i < viewer.surfs.length; i++) { + var surf = viewer.surfs[i].surf; + surf.hemis.left.attributes.contourData0.array = verts[0].array; + surf.hemis.left.attributes.contourData0.needsUpdate = true; + surf.hemis.right.attributes.contourData0.array = verts[1].array; + surf.hemis.right.attributes.contourData0.needsUpdate = true; + surf.hemis.left.attributes.contourData1.array = verts1[0].array; + surf.hemis.left.attributes.contourData1.needsUpdate = true; + surf.hemis.right.attributes.contourData1.array = verts1[1].array; + surf.hemis.right.attributes.contourData1.needsUpdate = true; + surf.uniforms.contourOverlay.value = 1; + // Set vmin/vmax and colormap for colored contour lookup + surf.uniforms.contourVmin.value = overlayView.vmin[0].value[0]; + surf.uniforms.contourVmax.value = overlayView.vmax[0].value[0]; + surf.uniforms.contourColormap.value = overlayView.cmap[0].value; + } + viewer.schedule(); + }; + + // Data may already be available (verts populated via progress callback) + // even though loaded.state() is "pending" (resolve() is never called + // for VertexData). Check verts directly. + if (overlayData.verts.length > 0) { + applyOverlay(); + } else { + // Data not yet available — wait for progress + overlayData.loaded.progress(function() { + if (overlayData.verts.length > 0) { + applyOverlay(); + } + }); + } + }; + module.Viewer.prototype.rmData = function(name) { delete this.datasets[name]; $(this.object).find("#datasets li").each(function() { diff --git a/cortex/webgl/resources/js/mriview_surface.js b/cortex/webgl/resources/js/mriview_surface.js index 5b1a788ea..c6d8b9a41 100644 --- a/cortex/webgl/resources/js/mriview_surface.js +++ b/cortex/webgl/resources/js/mriview_surface.js @@ -71,6 +71,15 @@ var mriview = (function(module) { contrast: { type:'f', value:parseFloat(viewopts.contrast)}, extratex: { type:'t', value:null}, + // Contour rendering + contourMode: { type:'f', value: 0 }, + contourThreshold: { type:'f', value: 0.01 }, + contourColor: { type:'v3', value: new THREE.Vector3(0, 0, 0) }, + contourOverlay: { type:'f', value: 0 }, + contourVmin: { type:'f', value: 0 }, + contourVmax: { type:'f', value: 1 }, + contourColormap: { type:'t', value: new THREE.DataTexture(new Uint8Array([0,0,0,255]), 1, 1, THREE.RGBAFormat) }, + // screen: { type:'t', value:this.volumebuf}, // screen_size:{ type:'v2', value:new THREE.Vector2(100, 100)}, } @@ -110,7 +119,8 @@ var mriview = (function(module) { sampler: {action:[this, "setSampler", ["nearest", "trilinear"]]}, uniform_illumination: {action:[this, "setUniformIllumination"]}, }); - + + this.ui.addFolder("contours", true); this.ui.addFolder("curvature", true).add({ brightness: {action:[this.uniforms.brightness, "value", 0, 1]}, @@ -257,6 +267,11 @@ var mriview = (function(module) { hemi.addAttribute("data2", new THREE.BufferAttribute(new Float32Array(), 1)); hemi.addAttribute("data3", new THREE.BufferAttribute(new Float32Array(), 1)); + //Queue contour overlay data attributes (pre-sized to vertex count for proper WebGL buffer allocation) + var nVerts = hemi.attributes.position.array.length / hemi.attributes.position.itemSize; + hemi.addAttribute("contourData0", new THREE.BufferAttribute(new Float32Array(nVerts), 1)); + hemi.addAttribute("contourData1", new THREE.BufferAttribute(new Float32Array(nVerts), 1)); + hemi.dynamic = true; var pivots = {back:new THREE.Group(), front:new THREE.Group()}; pivots.front.add(pivots.back); @@ -854,6 +869,13 @@ var mriview = (function(module) { this.object.add(this.mesh); } + module.Surface.prototype.setContourMode = function(val) { + if (val === undefined) + return this.uniforms.contourMode.value; + this.uniforms.contourMode.value = parseFloat(val); + viewer.schedule(); + }; + module.Surface.prototype.setUniformIllumination = function(val) { if (val === undefined) return this.uniforms.emissive.value.x == 1; // Check current state diff --git a/cortex/webgl/resources/js/shaderlib.js b/cortex/webgl/resources/js/shaderlib.js index 9eaeb73a7..ea7f77c19 100644 --- a/cortex/webgl/resources/js/shaderlib.js +++ b/cortex/webgl/resources/js/shaderlib.js @@ -382,6 +382,16 @@ var Shaderlib = (function() { "varying vec3 vWorldPosition;", // "varying float vDrop;", + // Contour overlay attributes and varyings + "attribute float contourData0;", + "attribute float contourData1;", + "varying float vContourDataValue;", + "varying vec4 vContourColor;", + "uniform float contourVmin;", + "uniform float contourVmax;", + "uniform sampler2D contourColormap;", + "uniform float framemix;", + "varying vec3 vPos_x[2];", "#ifdef TWOD", "varying vec3 vPos_y[2];", @@ -446,6 +456,12 @@ var Shaderlib = (function() { "gl_Position = projectionMatrix * modelViewMatrix * vec4( pos, 1.0 );", "vWorldPosition = pos;", + + // Contour overlay data + "vContourDataValue = mix(contourData0, contourData1, framemix);", + "float contourRange = contourVmax - contourVmin;", + "float contourNorm = contourRange > 0.0 ? clamp((vContourDataValue - contourVmin) / contourRange, 0.0, 1.0) : 0.0;", + "vContourColor = texture2D(contourColormap, vec2(contourNorm, 0.0));", "}" ].join("\n"); @@ -504,7 +520,15 @@ var Shaderlib = (function() { "varying float vMedial;", "varying float vThickmix;", "varying vec3 vWorldPosition;", // the x,y,z coordinates of this pixel - + + // Contour rendering uniforms and varyings + "varying float vContourDataValue;", + "varying vec4 vContourColor;", + "uniform float contourMode;", + "uniform float contourThreshold;", + "uniform vec3 contourColor;", + "uniform float contourOverlay;", + utils.standard_frag_vars, utils.rand, utils.edge, @@ -647,8 +671,41 @@ var Shaderlib = (function() { "vec4 tColor = (1. - step(.001, vMedial)) * texture2D(extratex, vUv);", "#endif", + // Contour edge detection (uses overlay vertex data even for pixel-shaded volumes) + "float contourEdge = contourOverlay > 0.5 ? fwidth(vContourDataValue) : 0.0;", + "bool isBorder = contourEdge > contourThreshold;", + "gl_FragColor = cColor;", - "gl_FragColor = vColor + (1.-vColor.a)*gl_FragColor;", + // contourMode: 0=off, 1=contours only, 2=contours+fill, + // 3=colored contours only, 4=colored contours+fill + "if (contourMode < 0.5) {", + // Mode 0: normal rendering, no contours + "gl_FragColor = vColor + (1.-vColor.a)*gl_FragColor;", + "} else if (contourMode < 1.5) {", + // Mode 1: contours only (interior = curvature, no data fill) + "if (isBorder) {", + "vec4 borderColor = vec4(contourColor, 1.0);", + "gl_FragColor = borderColor + (1.-borderColor.a)*gl_FragColor;", + "}", + "} else if (contourMode < 2.5) {", + // Mode 2: data + contour borders on top + "gl_FragColor = vColor + (1.-vColor.a)*gl_FragColor;", + "if (isBorder) {", + "vec4 borderColor = vec4(contourColor, 1.0);", + "gl_FragColor = borderColor + (1.-borderColor.a)*gl_FragColor;", + "}", + "} else if (contourMode < 3.5) {", + // Mode 3: colored contours only (no data fill) + "if (isBorder) {", + "gl_FragColor = vContourColor + (1.-vContourColor.a)*gl_FragColor;", + "}", + "} else {", + // Mode 4: data + colored contour borders on top + "gl_FragColor = vColor + (1.-vColor.a)*gl_FragColor;", + "if (isBorder) {", + "gl_FragColor = vContourColor + (1.-vContourColor.a)*gl_FragColor;", + "}", + "}", // "gl_FragColor = hColor + (1.-hColor.a)*gl_FragColor;", "#ifdef ROI_RENDER", "gl_FragColor = rColor + (1.-rColor.a)*gl_FragColor;", @@ -674,6 +731,9 @@ var Shaderlib = (function() { attributes.flatBumpNorms = { type: 'v3', value:null }; attributes.flatheight = { type: 'f', value:null }; } + attributes['contourData0'] = {type:'f', value:null}; + attributes['contourData1'] = {type:'f', value:null}; + for (var i = 0; i < morphs-1; i++) { attributes['mixSurfs'+i] = { type:'v4', value:null}; attributes['mixNorms'+i] = { type:'v3', value:null}; @@ -730,6 +790,14 @@ var Shaderlib = (function() { "varying vec2 vUv;", "varying float vCurv;", "varying float vMedial;", + "varying float vDataValue;", + "attribute float contourData0;", + "attribute float contourData1;", + "varying float vContourDataValue;", + "varying vec4 vContourColor;", + "uniform float contourVmin;", + "uniform float contourVmax;", + "uniform sampler2D contourColormap;", // "varying float vDrop;", utils.mixer(morphs), @@ -743,6 +811,7 @@ var Shaderlib = (function() { "#ifdef RGBCOLORS", "vColor = mix(data0, data1, framemix);", + "vDataValue = 0.0;", "#else", "vec2 cuv;", // "vValue.x = (mix(data0, data1, framemix) - vmin[0]) / (vmax[0] - vmin[0]);", @@ -752,7 +821,13 @@ var Shaderlib = (function() { "cuv.y = (mix(data2, data3, framemix) - vmin[1]) / (vmax[1] - vmin[1]);", "#endif", "vColor = texture2D(colormap, cuv);", + "vDataValue = mix(data0, data1, framemix);", "#endif", + "vContourDataValue = mix(contourData0, contourData1, framemix);", + // Look up contour color in the overlay's own colormap + "float contourRange = contourVmax - contourVmin;", + "float contourNorm = contourRange > 0.0 ? clamp((vContourDataValue - contourVmin) / contourRange, 0.0, 1.0) : 0.0;", + "vContourColor = texture2D(contourColormap, vec2(contourNorm, 0.0));", "#ifdef CORTSHEET", "vec3 mpos = mix(position, wm.xyz, use_thickmix);", @@ -806,9 +881,19 @@ var Shaderlib = (function() { // "varying float vDrop;", "varying float vCurv;", "varying float vMedial;", + "varying float vDataValue;", + "varying float vContourDataValue;", + "varying vec4 vContourColor;", "uniform float thickmix;", // utils.thickmixer, + // Contour rendering uniforms + // 0=off, 1=contours only, 2=contours+fill, 3=colored contours only, 4=colored contours+fill + "uniform float contourMode;", + "uniform float contourThreshold;", + "uniform vec3 contourColor;", + "uniform float contourOverlay;", // 0=use self data, 1=use overlay data + utils.standard_frag_vars, // "#ifdef RGBCOLORS", @@ -836,13 +921,41 @@ var Shaderlib = (function() { "vec4 tColor = (1. - step(.001, vMedial)) * texture2D(extratex, vUv);", "#endif", - // "#ifndef RGBCOLORS", - // "vec4 vColor = texture2D(colormap, vValue);", - // "#endif", - + // Contour edge detection + "float contourEdge = contourOverlay > 0.5 ? fwidth(vContourDataValue) : fwidth(vDataValue);", + "bool isBorder = contourEdge > contourThreshold;", "gl_FragColor = cColor;", - "gl_FragColor = vColor + (1.-vColor.a)*gl_FragColor;", + // contourMode: 0=off, 1=contours only, 2=contours+fill, + // 3=colored contours only, 4=colored contours+fill + "if (contourMode < 0.5) {", + // Mode 0: normal rendering, no contours + "gl_FragColor = vColor + (1.-vColor.a)*gl_FragColor;", + "} else if (contourMode < 1.5) {", + // Mode 1: contours only (interior = curvature) + "if (isBorder) {", + "vec4 borderColor = vec4(contourColor, 1.0);", + "gl_FragColor = borderColor + (1.-borderColor.a)*gl_FragColor;", + "}", + "} else if (contourMode < 2.5) {", + // Mode 2: data + contour borders on top + "gl_FragColor = vColor + (1.-vColor.a)*gl_FragColor;", + "if (isBorder) {", + "vec4 borderColor = vec4(contourColor, 1.0);", + "gl_FragColor = borderColor + (1.-borderColor.a)*gl_FragColor;", + "}", + "} else if (contourMode < 3.5) {", + // Mode 3: colored contours only (border color from overlay colormap) + "if (isBorder) {", + "gl_FragColor = vContourColor + (1.-vContourColor.a)*gl_FragColor;", + "}", + "} else {", + // Mode 4: data + colored contour borders on top + "gl_FragColor = vColor + (1.-vColor.a)*gl_FragColor;", + "if (isBorder) {", + "gl_FragColor = vContourColor + (1.-vContourColor.a)*gl_FragColor;", + "}", + "}", //"gl_FragColor = vec4(1., 0., 0., 1.);", // "gl_FragColor = hColor + (1.-hColor.a)*gl_FragColor;", "#ifdef ROI_RENDER", @@ -866,6 +979,9 @@ var Shaderlib = (function() { for (var i = 0; i < 4; i++) attributes['data'+i] = {type:opts.rgb ? 'v4':'f', value:null}; + attributes['contourData0'] = {type:'f', value:null}; + attributes['contourData1'] = {type:'f', value:null}; + for (var i = 0; i < morphs-1; i++) { attributes['mixSurfs'+i] = { type:'v4', value:null }; attributes['mixNorms'+i] = { type:'v3', value:null }; diff --git a/cortex/webgl/view.py b/cortex/webgl/view.py index d32ad7916..242ccd2cc 100644 --- a/cortex/webgl/view.py +++ b/cortex/webgl/view.py @@ -28,19 +28,25 @@ from .FallbackLoader import FallbackLoader try: - cmapdir = options.config.get('webgl', 'colormaps') + cmapdir = options.config.get("webgl", "colormaps") if not os.path.exists(cmapdir): - raise Exception("Colormap directory (%s) does not exist"%cmapdir) + raise Exception("Colormap directory (%s) does not exist" % cmapdir) except NoOptionError: cmapdir = os.path.join(options.config.get("basic", "filestore"), "colormaps") if not os.path.exists(cmapdir): - raise Exception("Colormap directory was not defined in the config file and the default (%s) does not exist"%cmapdir) + raise Exception( + "Colormap directory was not defined in the config file and the default (%s) does not exist" + % cmapdir + ) domain_name = options.config.get("webgl", "domain_name") colormaps = glob.glob(os.path.join(cmapdir, "*.png")) -colormaps = [(os.path.splitext(os.path.split(cm)[1])[0], serve.make_base64(cm)) - for cm in sorted(colormaps)] +colormaps = [ + (os.path.splitext(os.path.split(cm)[1])[0], serve.make_base64(cm)) + for cm in sorted(colormaps) +] + def make_static( outpath, @@ -130,7 +136,7 @@ def make_static( Smoothness of curvature overlay. Default None, which uses the value specified in the config file. surface_specularity : float or None, optional - Specularity of surfaces visualized with the WebGL viewer. + Specularity of surfaces visualized with the WebGL viewer. Default None, which uses the value specified in the config file under `webgl_viewopts.specularity`. **kwargs @@ -285,24 +291,24 @@ def make_static( def show( data: Union[dataset.Dataset, dataset.Dataview], - autoclose: Optional[bool]=None, - open_browser: Optional[bool]=None, - port: Optional[int]=None, - pickerfun: Optional[Callable[[tuple[int, int, int], int, str], None]]=None, - recache: bool=False, - template: str="mixer.html", - overlays_available: Optional[tuple[str, ...]]=None, - overlays_visible: Optional[tuple[str, ...]]=("rois", "sulci"), - labels_visible: Optional[tuple[str, ...]]=("rois",), - types: Optional[tuple[str, ...]]=("inflated",), - overlay_file: Optional[str]=None, - curvature_brightness: Optional[float]=None, - curvature_contrast: Optional[float]=None, - curvature_smoothness: Optional[float]=None, - surface_specularity: Optional[float]=None, - title: str="Brain", - layout: Optional[str]=None, - display_url: bool=True, + autoclose: Optional[bool] = None, + open_browser: Optional[bool] = None, + port: Optional[int] = None, + pickerfun: Optional[Callable[[tuple[int, int, int], int, str], None]] = None, + recache: bool = False, + template: str = "mixer.html", + overlays_available: Optional[tuple[str, ...]] = None, + overlays_visible: Optional[tuple[str, ...]] = ("rois", "sulci"), + labels_visible: Optional[tuple[str, ...]] = ("rois",), + types: Optional[tuple[str, ...]] = ("inflated",), + overlay_file: Optional[str] = None, + curvature_brightness: Optional[float] = None, + curvature_contrast: Optional[float] = None, + curvature_smoothness: Optional[float] = None, + surface_specularity: Optional[float] = None, + title: str = "Brain", + layout: Optional[str] = None, + display_url: bool = True, **kwargs, ): """ @@ -338,7 +344,7 @@ def show( overlays_available : tuple, optional Overlays available in the viewer. If None, then all overlay layers of the svg file will be potentially available in the viewer (whether initially - visible or not). + visible or not). overlays_visible : tuple, optional The listed overlay layers will be set visible by default. Layers not listed here will be hidden by default (but can be enabled in the viewer GUI). @@ -366,7 +372,7 @@ def show( Smoothness of curvature overlay. Default None, which uses the value specified in the config file. surface_specularity : float or None, optional - Specularity of surfaces visualized with the WebGL viewer. + Specularity of surfaces visualized with the WebGL viewer. Default None, which uses the value specified in the config file under `webgl_viewopts.specularity`. title : str, optional @@ -387,52 +393,61 @@ def show( # populate default webshow args if autoclose is None: - autoclose = options.config.get('webshow', 'autoclose', fallback='true') == 'true' + autoclose = ( + options.config.get("webshow", "autoclose", fallback="true") == "true" + ) if open_browser is None: - open_browser = options.config.get('webshow', 'open_browser', fallback='true') == 'true' + open_browser = ( + options.config.get("webshow", "open_browser", fallback="true") == "true" + ) data = dataset.normalize(data) if not isinstance(data, dataset.Dataset): data = dataset.Dataset(data=data) - html = FallbackLoader([os.path.split(os.path.abspath(template))[0], serve.cwd]).load(template) + html = FallbackLoader( + [os.path.split(os.path.abspath(template))[0], serve.cwd] + ).load(template) db.auxfile = data - #Extract the list of stimuli, for special-casing + # Extract the list of stimuli, for special-casing stims: dict[str, str] = dict() for name, view in data: - if 'stim' in view.attrs and os.path.exists(view.attrs['stim']): - sname = os.path.split(view.attrs['stim'])[1] - stims[sname] = view.attrs['stim'] + if "stim" in view.attrs and os.path.exists(view.attrs["stim"]): + sname = os.path.split(view.attrs["stim"])[1] + stims[sname] = view.attrs["stim"] package = Package(data) metadata = json.dumps(package.metadata()) images = package.images subjects = list(package.subjects) - ctmargs = dict(method='mg2', level=9, recache=recache, - external_svg=overlay_file, overlays_available=overlays_available) - ctms = dict((subj, utils.get_ctmpack(subj, types, **ctmargs)) - for subj in subjects) + ctmargs = dict( + method="mg2", + level=9, + recache=recache, + external_svg=overlay_file, + overlays_available=overlays_available, + ) + ctms = dict((subj, utils.get_ctmpack(subj, types, **ctmargs)) for subj in subjects) package.reorder(ctms) - subjectjs = json.dumps(dict((subj, "ctm/%s/"%subj) for subj in subjects)) + subjectjs = json.dumps(dict((subj, "ctm/%s/" % subj) for subj in subjects)) db.auxfile = None - - linear = lambda x, y, m: (1.-m)*x + m*y + linear = lambda x, y, m: (1.0 - m) * x + m * y mixes = dict( linear=linear, - smoothstep=(lambda x, y, m: linear(x, y, 3*m**2 - 2*m**3)), - smootherstep=(lambda x, y, m: linear(x, y, 6*m**5 - 15*m**4 + 10*m**3)) + smoothstep=(lambda x, y, m: linear(x, y, 3 * m**2 - 2 * m**3)), + smootherstep=(lambda x, y, m: linear(x, y, 6 * m**5 - 15 * m**4 + 10 * m**3)), ) post_name: Queue[str] = Queue() # Put together all view options - my_viewopts: dict[str, Any] = dict(options.config.items('webgl_viewopts')) - my_viewopts['overlays_visible'] = overlays_visible - my_viewopts['labels_visible'] = labels_visible + my_viewopts: dict[str, Any] = dict(options.config.items("webgl_viewopts")) + my_viewopts["overlays_visible"] = overlays_visible + my_viewopts["labels_visible"] = labels_visible my_viewopts["brightness"] = ( options.config.get("curvature", "brightness") if curvature_brightness is None @@ -455,7 +470,7 @@ def show( ) for sec in options.config.sections(): - if 'paths' in sec or 'labels' in sec: + if "paths" in sec or "labels" in sec: my_viewopts[sec] = dict(options.config.items(sec)) if pickerfun is None: @@ -463,8 +478,8 @@ def show( class CTMHandler(web.RequestHandler): def get(self, path: str): - subj, path = path.split('/') - if path == '': + subj, path = path.split("/") + if path == "": self.set_header("Content-Type", "application/json") self.write(open(ctms[subj]).read()) else: @@ -473,14 +488,14 @@ def get(self, path: str): if mtype is None: mtype = "application/octet-stream" self.set_header("Content-Type", mtype) - self.write(open(os.path.join(fpath, path), 'rb').read()) + self.write(open(os.path.join(fpath, path), "rb").read()) class DataHandler(web.RequestHandler): def get(self, path: str): path = path.strip("/") frame: Union[int, str] try: - dataname, frame = path.split('/') + dataname, frame = path.split("/") except ValueError: dataname = path frame = 0 @@ -492,15 +507,21 @@ def get(self, path: str): else: self.set_header("Content-Type", "image/png") - if 'Range' in self.request.headers: + if "Range" in self.request.headers: self.set_status(206) - rangestr = self.request.headers['Range'].split('=')[1] - start, end = [ int(i) if len(i) > 0 else None for i in rangestr.split('-') ] - - clenheader = 'bytes %s-%s/%s' % (start, end or len(dataimg), len(dataimg) ) - self.set_header('Content-Range', clenheader) - self.set_header('Content-Length', end-start+1) - self.write(dataimg[start:end+1]) + rangestr = self.request.headers["Range"].split("=")[1] + start, end = [ + int(i) if len(i) > 0 else None for i in rangestr.split("-") + ] + + clenheader = "bytes %s-%s/%s" % ( + start, + end or len(dataimg), + len(dataimg), + ) + self.set_header("Content-Range", clenheader) + self.set_header("Content-Length", end - start + 1) + self.write(dataimg[start : end + 1]) else: self.write(dataimg) else: @@ -521,24 +542,26 @@ def get(self, path: str): class StaticHandler(web.StaticFileHandler): def initialize(self): - self.root = '' + self.root = "" class MixerHandler(web.RequestHandler): def get(self): self.set_header("Content-Type", "text/html") - generated = html.generate(data=metadata, - colormaps=colormaps, - default_cmap="RdBu_r", - python_interface=True, - leapmotion=True, - layout=layout, - subjects=subjectjs, - viewopts=json.dumps(my_viewopts), - title=title, - **kwargs) - #overlays_visible=json.dumps(overlays_visible), - #labels_visible=json.dumps(labels_visible), - #**viewopts) + generated = html.generate( + data=metadata, + colormaps=colormaps, + default_cmap="RdBu_r", + python_interface=True, + leapmotion=True, + layout=layout, + subjects=subjectjs, + viewopts=json.dumps(my_viewopts), + title=title, + **kwargs, + ) + # overlays_visible=json.dumps(overlays_visible), + # labels_visible=json.dumps(labels_visible), + # **viewopts) self.write(generated) def post(self): @@ -554,13 +577,13 @@ def post(self): data = png svgfile.write(data) - P = ParamSpec('P') + P = ParamSpec("P") class JSMixer(serve.JSProxy[P]): @property def view_props(self) -> list[str]: - """An enumerated list of settable properties for views. - There may be a way to get this from the javascript object, + """An enumerated list of settable properties for views. + There may be a way to get this from the javascript object, but I (ML) don't know how. There may be additional properties we want to set in views @@ -572,15 +595,24 @@ def view_props(self) -> list[str]: 'volume_vis', 'frame', 'slices'] """ camera = getattr(self.ui, "camera") - _camera_props = ['camera.%s' % k for k in camera._controls.attrs.keys()] + _camera_props = ["camera.%s" % k for k in camera._controls.attrs.keys()] surface = getattr(self.ui, "surface") _subject = list(surface._folders.attrs.keys())[0] _surface = getattr(surface, _subject) - _surface_props = ['surface.{subject}.%s'%k for k in _surface._controls.attrs.keys()] - _curvature_props = ['surface.{subject}.curvature.brightness', - 'surface.{subject}.curvature.contrast', - 'surface.{subject}.curvature.smoothness'] - return _camera_props + _surface_props + _curvature_props + _surface_props = [ + "surface.{subject}.%s" % k for k in _surface._controls.attrs.keys() + ] + _curvature_props = [ + "surface.{subject}.curvature.brightness", + "surface.{subject}.curvature.contrast", + "surface.{subject}.curvature.smoothness", + ] + _contour_props = [ + "surface.{subject}.contours.mode", + "surface.{subject}.contours.threshold", + "surface.{subject}.contours.overlay", + ] + return _camera_props + _surface_props + _curvature_props + _contour_props def _set_view(self, **kwargs): """Low-level command: sets view parameters in the current viewer @@ -588,24 +620,36 @@ def _set_view(self, **kwargs): Sets each the state of each keyword argument provided. View parameters that can be set include all parameters in the data.gui in the html view. + Contour-related parameters: + surface.{subject}.contours.mode : int + 0=off, 1=contours only, 2=contours+fill + surface.{subject}.contours.threshold : float + Edge detection threshold (0.001-0.5) + surface.{subject}.contours.overlay : str + Dataset name to use as contour overlay, or "none" + """ # Set unfolding level first, as it interacts with other arguments assert isinstance(self.ui, serve.JSProxy) surface: serve.JSProxy[P] = getattr(self.ui, "surface") subject_list = cast(serve.JSProxy[P], surface._folders).attrs.keys() - # Better to only self.view_props once; it interacts with javascript, + # Better to only self.view_props once; it interacts with javascript, # don't want to do that too often, it leads to glitches. vw_props = copy.copy(self.view_props) for subject in subject_list: - if 'surface.{subject}.unfold' in kwargs: - unfold = kwargs.pop('surface.{subject}.unfold') - self.ui.set('surface.{subject}.unfold'.format(subject=subject), unfold) + if "surface.{subject}.unfold" in kwargs: + unfold = kwargs.pop("surface.{subject}.unfold") + self.ui.set( + "surface.{subject}.unfold".format(subject=subject), unfold + ) for k, v in kwargs.items(): if not k in vw_props: - print('Unknown parameter %s!'%k) + print("Unknown parameter %s!" % k) continue else: - self.ui.set(k.format(subject=subject) if '{subject}' in k else k, v) + self.ui.set( + k.format(subject=subject) if "{subject}" in k else k, v + ) # Wait for webgl. Wait for it. .... WAAAAAIIIT. time.sleep(0.03) @@ -621,7 +665,7 @@ def _capture_view(self, frame_time=None): ---------- frame_time : scalar time (in seconds) to specify for this frame. - + Notes ----- If multiple subjects are present, only retrieves view for first subject. @@ -630,16 +674,18 @@ def _capture_view(self, frame_time=None): subject = list(self.ui.surface._folders.attrs.keys())[0] for p in self.view_props: try: - view[p] = self.ui.get(p.format(subject=subject) if '{subject}' in p else p)[0] + view[p] = self.ui.get( + p.format(subject=subject) if "{subject}" in p else p + )[0] # Wait for webgl. time.sleep(0.03) except Exception as err: # TO DO: Fix this hack with an error class in serve.py & catch it here - print(err) #msg = "Cannot read property 'undefined'" - #if err.message[:len(msg)] != msg: + print(err) # msg = "Cannot read property 'undefined'" + # if err.message[:len(msg)] != msg: # raise err if frame_time is not None: - view['time'] = frame_time + view["time"] = frame_time return view def save_view(self, subject, name, is_overwrite=False): @@ -683,12 +729,14 @@ def get_view(self, subject, name): def addData(self, **kwargs): Proxy = serve.JSProxy(self.send, "window.viewers.addData") - new_meta, new_ims = _convert_dataset(Dataset(**kwargs), path='/data/', fmt='%s_%d.png') + new_meta, new_ims = _convert_dataset( + Dataset(**kwargs), path="/data/", fmt="%s_%d.png" + ) metadata.update(new_meta) images.update(new_ims) return Proxy(metadata) - def getImage(self, filename: str, size: tuple[int, int]=(1920, 1080)): + def getImage(self, filename: str, size: tuple[int, int] = (1920, 1080)): """Saves currently displayed view to a .png image file Parameters @@ -702,8 +750,15 @@ def getImage(self, filename: str, size: tuple[int, int]=(1920, 1080)): Proxy = serve.JSProxy(self.send, "window.viewer.getImage") return Proxy(size[0], size[1], "mixer.html") - def makeMovie(self, animation, filename="brainmovie%07d.png", offset=0, - fps=30, size=(1920, 1080), interpolation="linear"): + def makeMovie( + self, + animation, + filename="brainmovie%07d.png", + offset=0, + fps=30, + size=(1920, 1080), + interpolation="linear", + ): """Renders movie frames for animation of mesh movement Makes an animation (for example, a transition between inflated and @@ -754,39 +809,47 @@ def makeMovie(self, animation, filename="brainmovie%07d.png", offset=0, # anim is a list of transitions between keyframes anim = [] setfunc = self.ui.set - for f in sorted(animation, key=lambda x:x['idx']): - if f['idx'] == 0: - setfunc(f['state'], f['value']) - state[f['state']] = dict(idx=f['idx'], val=f['value']) + for f in sorted(animation, key=lambda x: x["idx"]): + if f["idx"] == 0: + setfunc(f["state"], f["value"]) + state[f["state"]] = dict(idx=f["idx"], val=f["value"]) else: - if f['state'] not in state: - state[f['state']] = dict(idx=0, val=self.getState(f['state'])[0]) - start = dict(idx=state[f['state']]['idx'], - state=f['state'], - value=state[f['state']]['val']) - end = dict(idx=f['idx'], state=f['state'], value=f['value']) - state[f['state']]['idx'] = f['idx'] - state[f['state']]['val'] = f['value'] - if start['value'] != end['value']: + if f["state"] not in state: + state[f["state"]] = dict( + idx=0, val=self.getState(f["state"])[0] + ) + start = dict( + idx=state[f["state"]]["idx"], + state=f["state"], + value=state[f["state"]]["val"], + ) + end = dict(idx=f["idx"], state=f["state"], value=f["value"]) + state[f["state"]]["idx"] = f["idx"] + state[f["state"]]["val"] = f["value"] + if start["value"] != end["value"]: anim.append((start, end)) - for i, sec in enumerate(np.arange(0, anim[-1][1]['idx']+1./fps, 1./fps)): + for i, sec in enumerate( + np.arange(0, anim[-1][1]["idx"] + 1.0 / fps, 1.0 / fps) + ): for start, end in anim: - if start['idx'] < sec <= end['idx']: - idx = (sec - start['idx']) / float(end['idx'] - start['idx']) - if start['state'] == 'frame': - func = mixes['linear'] + if start["idx"] < sec <= end["idx"]: + idx = (sec - start["idx"]) / float(end["idx"] - start["idx"]) + if start["state"] == "frame": + func = mixes["linear"] else: func = mixes[interpolation] - val = func(np.array(start['value']), np.array(end['value']), idx) + val = func( + np.array(start["value"]), np.array(end["value"]), idx + ) if isinstance(val, np.ndarray): - setfunc(start['state'], val.ravel().tolist()) + setfunc(start["state"], val.ravel().tolist()) else: - setfunc(start['state'], val) - self.getImage(filename%(i+offset), size=size) + setfunc(start["state"], val) + self.getImage(filename % (i + offset), size=size) - def _get_anim_seq(self, keyframes, fps=30, interpolation='linear'): + def _get_anim_seq(self, keyframes, fps=30, interpolation="linear"): """Convert a list of keyframes to a list of EVERY frame in an animation. Utility function called by make_movie; separated out so that individual @@ -798,23 +861,23 @@ def _get_anim_seq(self, keyframes, fps=30, interpolation='linear'): fr = 0 a = np.array func = mixes[interpolation] - #skip_props = ['surface.{subject}.right', 'surface.{subject}.left', ] #'projection', + # skip_props = ['surface.{subject}.right', 'surface.{subject}.left', ] #'projection', # Get keyframes - keyframes = sorted(keyframes, key=lambda x:x['time']) + keyframes = sorted(keyframes, key=lambda x: x["time"]) # Normalize all time to frame rate - fs = 1./fps + fs = 1.0 / fps for k in range(len(keyframes)): - t = keyframes[k]['time'] - t = np.round(t/fs)*fs - keyframes[k]['time'] = t + t = keyframes[k]["time"] + t = np.round(t / fs) * fs + keyframes[k]["time"] = t allframes = [] for start, end in zip(keyframes[:-1], keyframes[1:]): - t0 = start['time'] - t1 = end['time'] - tdif = float(t1-t0) + t0 = start["time"] + t1 = end["time"] + tdif = float(t1 - t0) # Check whether to continue frame sequence to endpoint - use_endpoint = keyframes[-1]==end - nvalues = np.round(tdif/fs).astype(int) + use_endpoint = keyframes[-1] == end + nvalues = np.round(tdif / fs).astype(int) if use_endpoint: nvalues += 1 fr_time = np.linspace(0, 1, nvalues, endpoint=use_endpoint) @@ -822,9 +885,13 @@ def _get_anim_seq(self, keyframes, fps=30, interpolation='linear'): for t in fr_time: frame = {} for prop in start.keys(): - if prop=='time': + if prop == "time": continue - if (start[prop] is None) or (start[prop] == end[prop]) or isinstance(start[prop], (bool, str)): + if ( + (start[prop] is None) + or (start[prop] == end[prop]) + or isinstance(start[prop], (bool, str)) + ): frame[prop] = start[prop] continue val = func(a(start[prop]), a(end[prop]), t) @@ -835,9 +902,18 @@ def _get_anim_seq(self, keyframes, fps=30, interpolation='linear'): allframes.append(frame) return allframes - def make_movie_views(self, animation, filename="brainmovie%07d.png", - offset=0, fps=30, size=(1920, 1080), alpha=1, frame_sleep=0.05, - frame_start=0, interpolation="linear"): + def make_movie_views( + self, + animation, + filename="brainmovie%07d.png", + offset=0, + fps=30, + size=(1920, 1080), + alpha=1, + frame_sleep=0.05, + frame_start=0, + interpolation="linear", + ): """Renders movie frames for animation of mesh movement Makes an animation (for example, a transition between inflated and @@ -894,7 +970,7 @@ def make_movie_views(self, animation, filename="brainmovie%07d.png", for fr, frame in enumerate(allframes[frame_start:], frame_start): self._set_view(**frame) time.sleep(frame_sleep) - self.getImage(filename%(fr+offset+1), size=size) + self.getImage(filename % (fr + offset + 1), size=size) time.sleep(frame_sleep) class PickerHandler(web.RequestHandler): @@ -907,7 +983,9 @@ def get(self): parts = voxel_arg.split(",") if len(parts) != 3: self.set_status(400) - self.finish("Invalid 'voxel' query parameter: expected 3 comma-separated integers") + self.finish( + "Invalid 'voxel' query parameter: expected 3 comma-separated integers" + ) return try: voxel: tuple[int, int, int] = tuple(int(i) for i in parts) @@ -921,6 +999,7 @@ def get(self): class WebApp(serve.WebApp): disconnect_on_close = autoclose + def get_client(self): self.connect.wait() self.connect.clear() @@ -932,18 +1011,22 @@ def get_local_client(self): if port is None: port = random.randint(1024, 65536) - server = WebApp([(r'/ctm/(.*)', CTMHandler), - (r'/data/(.*)', DataHandler), - (r'/stim/(.*)', StimHandler), - (r'/mixer.html', MixerHandler), - (r'/picker', PickerHandler), - (r'/', MixerHandler), - (r'/static/(.*)', StaticHandler)], - port) + server = WebApp( + [ + (r"/ctm/(.*)", CTMHandler), + (r"/data/(.*)", DataHandler), + (r"/stim/(.*)", StimHandler), + (r"/mixer.html", MixerHandler), + (r"/picker", PickerHandler), + (r"/", MixerHandler), + (r"/static/(.*)", StaticHandler), + ], + port, + ) server.start() - print("Started server on port %d"%server.port) - url = "http://%s%s:%d/mixer.html"%(serve.hostname, domain_name, server.port) + print("Started server on port %d" % server.port) + url = "http://%s%s:%d/mixer.html" % (serve.hostname, domain_name, server.port) if open_browser: webbrowser.open(url) client = server.get_client() @@ -952,7 +1035,10 @@ def get_local_client(self): elif display_url: try: from IPython.display import HTML, display - display(HTML('Open viewer: {0}'.format(url))) + + display( + HTML('Open viewer: {0}'.format(url)) + ) except: pass diff --git a/examples/quickflat/plot_contours.py b/examples/quickflat/plot_contours.py new file mode 100644 index 000000000..4eefe5931 --- /dev/null +++ b/examples/quickflat/plot_contours.py @@ -0,0 +1,108 @@ +""" +=============================== +Plot parcellation contour lines +=============================== + +Parcellation contour lines can be overlaid on top of data to delineate +region boundaries without obscuring the underlying activation map. + +This is useful when you want to show, for example, fMRI activation data +with anatomical or functional parcellation borders drawn on top. + +The ``with_contours`` parameter accepts a :class:`cortex.Vertex` (or any +Dataview) whose label boundaries will be drawn as contour lines. You can +customise the line color with ``contour_linecolor`` and the line width +with ``contour_linewidth``. +""" + +import cortex +import matplotlib.pyplot as plt +import numpy as np +from collections import deque + +np.random.seed(1234) + +subject = "S1" +n_verts = cortex.db.get_surf(subject, "fiducial", merge=True)[0].shape[0] + +############################################################################### +# Create a random parcellation +# ---------------------------- +# We generate a parcellation by growing 30 random seed vertices across the +# mesh using breadth-first search. Each seed becomes a parcel. + +n_parcels = 30 +_, polys = cortex.db.get_surf(subject, "fiducial", merge=True) +neighbors = cortex.utils._get_neighbors_dict(polys) + +parcellation = np.zeros(n_verts, dtype=float) +seeds = np.random.choice(n_verts, n_parcels, replace=False) +for i, s in enumerate(seeds, 1): + parcellation[s] = float(i) + +queue = deque(seeds.tolist()) +while queue: + v = queue.popleft() + for nb in neighbors.get(v, []): + if nb < n_verts and parcellation[nb] == 0: + parcellation[nb] = parcellation[v] + queue.append(nb) + +# Create Vertex objects +parc_vertex = cortex.Vertex(parcellation, subject, cmap="Set1", vmin=0, vmax=n_parcels) +activation = cortex.Vertex( + np.random.randn(n_verts), subject, cmap="RdBu_r", vmin=-2, vmax=2 +) + +############################################################################### +# Activation data with parcellation contours +# ------------------------------------------- +# Pass a Dataview to ``with_contours`` to draw its label boundaries on top +# of the primary data. + +fig = cortex.quickshow( + activation, + with_contours=parc_vertex, + with_curvature=True, + with_rois=False, + with_colorbar=True, + height=1024, +) +fig.suptitle("Activation + parcellation contours", fontsize=14) +plt.show() + +############################################################################### +# Custom contour color and width +# ------------------------------ +# Use ``contour_linecolor`` (RGBA tuple) and ``contour_linewidth`` (pixels) +# to customise the contour appearance. + +fig = cortex.quickshow( + activation, + with_contours=parc_vertex, + contour_linecolor=(1, 0, 0, 1), + contour_linewidth=3, + with_curvature=True, + with_rois=False, + with_colorbar=False, + height=1024, +) +fig.suptitle("Red thick contours", fontsize=14) +plt.show() + +############################################################################### +# Parcellation contours on curvature +# ----------------------------------- +# You can also overlay the contours of the parcellation on its own data +# to see both the filled colors and the borders. + +fig = cortex.quickshow( + parc_vertex, + with_contours=parc_vertex, + with_curvature=True, + with_rois=False, + with_colorbar=False, + height=1024, +) +fig.suptitle("Parcellation with contour borders", fontsize=14) +plt.show() diff --git a/examples/webgl/plot_contours_headless.py b/examples/webgl/plot_contours_headless.py new file mode 100644 index 000000000..fbbd321de --- /dev/null +++ b/examples/webgl/plot_contours_headless.py @@ -0,0 +1,107 @@ +""" +=============================================== +Plot parcellation contours on 3D brain (headless) +=============================================== + +The WebGL viewer supports contour rendering of parcellation borders on +the 3D cortical surface. ``cortex.export.save_3d_views`` accepts a +``contour_overlay`` parameter — pass a ``cortex.Vertex`` with parcellation +labels and the function automatically bundles it with the primary data, +enabling contour borders in the rendered views. + +Available contour modes (``contour_mode``): + +- ``"contours"``: borders only on curvature +- ``"contours+fill"``: data with solid-colour borders (default) +- ``"colored"``: borders coloured by the overlay's colormap +- ``"colored+fill"``: data with colormap-coloured borders + +Prerequisites +------------- +Install Playwright and download the bundled Chromium binary once:: + + pip install playwright + playwright install chromium + +""" + +import os +import tempfile +from collections import deque + +import matplotlib.pyplot as plt +import numpy as np + +import cortex +import cortex.export + +np.random.seed(1234) + +subject = "S1" +n_verts = cortex.db.get_surf(subject, "fiducial", merge=True)[0].shape[0] + +############################################################################### +# Create a random parcellation +# ---------------------------- +# Grow 30 random seed vertices across the mesh using breadth-first search. + +n_parcels = 30 +_, polys = cortex.db.get_surf(subject, "fiducial", merge=True) +neighbors = cortex.utils._get_neighbors_dict(polys) + +parcellation = np.zeros(n_verts, dtype=float) +seeds = np.random.choice(n_verts, n_parcels, replace=False) +for i, s in enumerate(seeds, 1): + parcellation[s] = float(i) + +queue = deque(seeds.tolist()) +while queue: + v = queue.popleft() + for nb in neighbors.get(v, []): + if nb < n_verts and parcellation[nb] == 0: + parcellation[nb] = parcellation[v] + queue.append(nb) + +############################################################################### +# Create data and parcellation Vertex objects +# -------------------------------------------- + +activation = cortex.Vertex( + np.random.randn(n_verts), subject, cmap="RdBu_r", vmin=-2, vmax=2 +) +parc_vertex = cortex.Vertex(parcellation, subject, cmap="Set1", vmin=0, vmax=n_parcels) + +############################################################################### +# Render with parcellation contour overlay +# ------------------------------------------ +# Pass the parcellation ``Vertex`` directly as ``contour_overlay``. +# The function wraps both into a Dataset automatically. +# The default ``contour_mode="contours+fill"`` draws black borders. + +base_name = os.path.join(tempfile.mkdtemp(), "contour") + +fnames = cortex.export.save_3d_views( + activation, + base_name=base_name, + list_angles=["left"], + list_surfaces=["inflated"], + viewer_params=dict(labels_visible=[], overlays_visible=[]), + size=(1920 * 2, 1080 * 2), + trim=True, + headless=True, + contour_overlay=parc_vertex, +) + +for fname in fnames: + img = plt.imread(fname) + aspect = img.shape[0] / img.shape[1] + fig, ax = plt.subplots(figsize=(10, 10 * aspect)) + ax.imshow(img) + ax.axis("off") + ax.set_title( + "Activation + parcellation contours (inflated, left)", + fontsize=14, + fontweight="bold", + ) + fig.subplots_adjust(left=0, right=1, top=0.92, bottom=0) + plt.show()