diff --git a/cortex/export/panels.py b/cortex/export/panels.py
index 8335a33fd..a45cb31d6 100644
--- a/cortex/export/panels.py
+++ b/cortex/export/panels.py
@@ -52,6 +52,8 @@ def plot_panels(
interpolation: str = "nearest",
layers: int = 1,
headless: bool = False,
+ contour_overlay: Optional[Union[str, Dataview]] = None,
+ contour_mode: str = "contours+fill",
) -> Figure:
"""Plot on the same figure a number of views, as defined by a list of panel
specifications.
@@ -110,6 +112,15 @@ def plot_panels(
Software WebGL (SwiftShader) is used, so no GPU or display server is
needed. (Default: False)
+ contour_overlay : Dataview, str, or None
+ Parcellation data whose borders will be drawn as contour lines.
+ Can be a Vertex/Dataview (automatically bundled into a Dataset),
+ or a string naming a view within an existing Dataset. (Default: None)
+
+ contour_mode : str
+ Contour rendering mode. Options: "contours", "contours+fill",
+ "colored", "colored+fill". (Default: "contours+fill")
+
Returns
-------
fig : matplotlib.Figure
@@ -150,6 +161,8 @@ def plot_panels(
interpolation=interpolation,
layers=layers,
headless=headless,
+ contour_overlay=contour_overlay,
+ contour_mode=contour_mode,
)
fig = plt.figure(figsize=figsize)
diff --git a/cortex/export/save_views.py b/cortex/export/save_views.py
index 2ea757510..f884b965c 100644
--- a/cortex/export/save_views.py
+++ b/cortex/export/save_views.py
@@ -1,7 +1,7 @@
import contextlib
import os
import time
-from typing import Any, Mapping, Sequence, TypedDict, Union
+from typing import Any, Mapping, Optional, Sequence, TypedDict, Union
import cortex
@@ -38,6 +38,8 @@ def save_3d_views(
trim: bool = True,
sleep: float = 10,
headless: bool = False,
+ contour_overlay: Optional[Union[str, Dataview]] = None,
+ contour_mode: str = "contours+fill",
) -> list[str]:
"""Saves 3D views of `volume` under multiple specifications.
@@ -47,13 +49,13 @@ def save_3d_views(
Parameters
----------
- volume: pycortex.Volume or pycortex.Vertex object
+ volume : pycortex.Volume, pycortex.Vertex, or pycortex.Dataset
Data to be displayed.
- base_name: str
+ base_name : str
Base name for images.
- list_angles: list of (str or dict)
+ list_angles : list of (str or dict)
Views to be used. Should be of length one, or of the same length as
`list_surfaces`. Choices are:
'left', 'right', 'front', 'back', 'top', 'bottom', 'flatmap',
@@ -61,33 +63,33 @@ def save_3d_views(
or tuple of (view_name, custom dictionary of parameters).
See `angle_view_params` in this file for parameter dict examples.
- list_surfaces: list of (str or dict)
+ list_surfaces : list of (str or dict)
Surfaces to be used. Should be of length one, or of the same length as
`list_angles`. Choices are:
'inflated', 'flatmap', 'fiducial', 'inflated_cut',
or a custom dictionary of parameters.
- viewer_params: dict
+ viewer_params : dict
Parameters passed to the viewer.
- interpolation: str
+ interpolation : str
Interpolation used to visualize the data. Possible choices are "nearest",
"trilinear". (Default: "nearest").
- layers: int
+ layers : int
Number of layers between the white and pial surfaces to average prior to
plotting the data. (Default: 1).
- size: tuple of int
+ size : tuple of int
Size of produced image (before trimming).
- trim: bool
+ trim : bool
Whether to trim the white borders of the image.
- sleep: float > 0
+ sleep : float > 0
Time in seconds, to let the viewer open.
- headless: bool
+ headless : bool
If True, render using a headless Chromium browser via Playwright instead
of requiring the user to manually open a browser window. This allows
the function to run fully autonomously without any user interaction.
@@ -96,14 +98,39 @@ def save_3d_views(
Software WebGL (SwiftShader) is used, so no GPU or display server is
needed. (Default: False)
+ contour_overlay : Dataview, str, or None
+ Parcellation data whose borders will be drawn as contour lines.
+ Can be a Vertex/Dataview object (automatically bundled into a Dataset
+ with ``volume``), or a string naming a view within an existing Dataset
+ passed as ``volume``. (Default: None)
+
+ contour_mode : str
+ Contour rendering mode when ``contour_overlay`` is set.
+ Options: "contours", "contours+fill", "colored", "colored+fill".
+ (Default: "contours+fill")
+
Returns
-------
- file_names: list of str
+ file_names : list of str
Image paths.
"""
msg = "list_angles and list_surfaces should have the same length."
assert len(list_angles) == len(list_surfaces), msg
+ # If contour_overlay is a Dataview, bundle volume + overlay into a Dataset.
+ # Preserve the original volume reference for isinstance checks below.
+ _contour_overlay_name = None
+ _original_volume = volume
+ if contour_overlay is not None:
+ if isinstance(contour_overlay, str):
+ _contour_overlay_name = contour_overlay
+ else:
+ # contour_overlay is a Dataview — wrap into Dataset
+ _contour_overlay_name = "__contour_overlay__"
+ volume = cortex.Dataset(
+ data=volume, **{_contour_overlay_name: contour_overlay}
+ )
+
# Create viewer — use a proper context manager so that cleanup always
# runs, even if an exception occurs during rendering.
if headless:
@@ -117,8 +144,39 @@ def save_3d_views(
# Wait for the viewer to be loaded
time.sleep(sleep)
- # Add interpolation and layers params only if we have a volume
- if isinstance(volume, (cortex.Volume, cortex.Volume2D, cortex.VolumeRGB)):
+ # Set up contour overlay if requested
+ if _contour_overlay_name is not None:
+ _contour_mode_map = {
+ "contours": 1,
+ "contours+fill": 2,
+ "colored": 3,
+ "colored+fill": 4,
+ }
+ if contour_mode not in _contour_mode_map:
+ raise ValueError(
+ f"Unknown contour_mode {contour_mode!r}. "
+ f"Valid options: {list(_contour_mode_map.keys())}"
+ )
+ _contour_mode_int = _contour_mode_map[contour_mode]
+ handle._set_view(
+ **{
+ "surface.{subject}.contours.overlay": _contour_overlay_name,
+ }
+ )
+ # Wait for overlay data to load
+ time.sleep(sleep)
+ handle._set_view(
+ **{
+ "surface.{subject}.contours.mode": _contour_mode_int,
+ }
+ )
+ time.sleep(1)
+
+ # Add interpolation and layers params only if the primary data is a volume.
+ # Use _original_volume (before Dataset wrapping) for the type check.
+ if isinstance(
+ _original_volume, (cortex.Volume, cortex.Volume2D, cortex.VolumeRGB)
+ ):
interpolation_params = {
"surface.{subject}.sampler": interpolation,
"surface.{subject}.layers": layers,
@@ -126,7 +184,13 @@ def save_3d_views(
else:
interpolation_params = dict()
- has_flatmap = hasattr(getattr(cortex.db, volume.subject).surfaces, "flat")
+ # Get subject name — handle both Dataview and Dataset
+ if hasattr(_original_volume, "subject"):
+ _subject = _original_volume.subject
+ else:
+ # Dataset: get subject from first view
+ _subject = next(iter(volume))[1].subject
+ has_flatmap = hasattr(getattr(cortex.db, _subject).surfaces, "flat")
file_names: list[str] = []
for view, surface in zip(list_angles, list_surfaces):
if isinstance(view, str):
@@ -163,7 +227,7 @@ def save_3d_views(
# wait for the view to have changed
for _ in range(100):
for k, v in this_view_params.items():
- k = k.format(subject=volume.subject) if "{subject}" in k else k
+ k = k.format(subject=_subject) if "{subject}" in k else k
if handle.ui.get(k)[0] != v:
print("waiting for", k, handle.ui.get(k)[0], "->", v)
time.sleep(0.1)
diff --git a/cortex/quickflat/composite.py b/cortex/quickflat/composite.py
index b12104911..880b303e0 100644
--- a/cortex/quickflat/composite.py
+++ b/cortex/quickflat/composite.py
@@ -3,16 +3,39 @@
from .. import dataset
from ..database import db
from ..options import config
-from .utils import _get_height, _get_extents, _convert_svg_kwargs, _get_images, _parse_defaults
-from .utils import make_flatmap_image, _make_hatch_image, _get_fig_and_ax, get_flatmask, get_flatcache
+from .utils import (
+ _get_height,
+ _get_extents,
+ _convert_svg_kwargs,
+ _get_images,
+ _parse_defaults,
+)
+from .utils import (
+ make_flatmap_image,
+ _make_hatch_image,
+ _get_fig_and_ax,
+ get_flatmask,
+ get_flatcache,
+)
""" --- Individual compositing functions --- """
-def add_curvature(fig, dataview, extents=None, height=None, threshold=True, contrast=None,
- brightness=None, smooth=None, cmap='gray', recache=False, curvature_lims=0.5,
- legacy_mode=False):
+def add_curvature(
+ fig,
+ dataview,
+ extents=None,
+ height=None,
+ threshold=True,
+ contrast=None,
+ brightness=None,
+ smooth=None,
+ cmap="gray",
+ recache=False,
+ curvature_lims=0.5,
+ legacy_mode=False,
+):
"""Add curvature layer to figure
Parameters
@@ -22,13 +45,13 @@ def add_curvature(fig, dataview, extents=None, height=None, threshold=True, cont
dataview : cortex.Dataview object
dataview containing data to be plotted, subject (surface identifier), and transform.
extents : array-like
- 4 values for [Left, Right, Top, Bottom] extents of image plotted. None defaults to
+ 4 values for [Left, Right, Top, Bottom] extents of image plotted. None defaults to
extents of images already present in figure.
height : scalar
- Height of image. None defaults to height of images already present in figure.
+ Height of image. None defaults to height of images already present in figure.
threshold : boolean
Whether to apply a threshold to the curvature values to create a binary curvature image
- (one shade for positive curvature, one shade for negative). `None` defaults to value
+ (one shade for positive curvature, one shade for negative). `None` defaults to value
specified in the config file
contrast : float, [0-1] or None
Contrast of curvature image. 1 is maximal contrast (given brightness). If brightness is 0.5
@@ -59,11 +82,12 @@ def add_curvature(fig, dataview, extents=None, height=None, threshold=True, cont
"""
from matplotlib.colors import Normalize
+
if height is None:
height = _get_height(fig)
# Get curvature map as image
- default_smoothing = config.get('curvature', 'smooth')
- if default_smoothing.lower()=='none':
+ default_smoothing = config.get("curvature", "smooth")
+ if default_smoothing.lower() == "none":
default_smoothing = None
else:
default_smoothing = np.float_(default_smoothing)
@@ -83,10 +107,16 @@ def add_curvature(fig, dataview, extents=None, height=None, threshold=True, cont
norm = Normalize(vmin=-0.5, vmax=0.5)
curv_im = norm(curv)
# Option to use thresholded curvature
- default_threshold = config.get('curvature','threshold').lower() in ('true', 't', '1', 'y', 'yes')
+ default_threshold = config.get("curvature", "threshold").lower() in (
+ "true",
+ "t",
+ "1",
+ "y",
+ "yes",
+ )
use_threshold_curvature = default_threshold if threshold is None else threshold
if legacy_mode and use_threshold_curvature:
- curvT = (curv>0).astype(np.float32)
+ curvT = (curv > 0).astype(np.float32)
curvT[np.isnan(curv)] = np.nan
curv = curvT
if isinstance(curvature_lims, (list, tuple)):
@@ -102,26 +132,38 @@ def add_curvature(fig, dataview, extents=None, height=None, threshold=True, cont
curv_im[np.isnan(curv)] = np.nan
# Get defaults for brightness, contrast
if brightness is None:
- brightness = float(config.get('curvature', 'brightness'))
+ brightness = float(config.get("curvature", "brightness"))
if contrast is None:
- contrast = float(config.get('curvature', 'contrast'))
+ contrast = float(config.get("curvature", "contrast"))
# Scale and shift curvature image
curv_im = (curv_im - 0.5) * contrast + brightness
if extents is None:
extents = _get_extents(fig)
_, ax = _get_fig_and_ax(fig)
- cvimg = ax.imshow(curv_im,
- aspect='equal',
- extent=extents,
- cmap=cmap,
- vmin=0,
- vmax=1,
- label='curvature',
- zorder=0)
+ cvimg = ax.imshow(
+ curv_im,
+ aspect="equal",
+ extent=extents,
+ cmap=cmap,
+ vmin=0,
+ vmax=1,
+ label="curvature",
+ zorder=0,
+ )
return cvimg
-def add_data(fig, braindata, height=1024, thick=32, depth=0.5, pixelwise=True,
- sampler='nearest', recache=False, nanmean=False):
+
+def add_data(
+ fig,
+ braindata,
+ height=1024,
+ thick=32,
+ depth=0.5,
+ pixelwise=True,
+ sampler="nearest",
+ recache=False,
+ nanmean=False,
+):
"""Add data to quickflat plot
Parameters
@@ -157,24 +199,44 @@ def add_data(fig, braindata, height=1024, thick=32, depth=0.5, pixelwise=True,
if not isinstance(dataview, dataset.Dataview):
# Unclear what this means. Clarify error in terms of pycortex classes
# (please provide a [cortex.dataset.Dataview or whatever] instance)
- raise TypeError('Please provide a Dataview, not a Dataset')
+ raise TypeError("Please provide a Dataview, not a Dataset")
# Generate image (2D array, maybe 3D array)
- im, extents = make_flatmap_image(dataview, recache=recache, pixelwise=pixelwise, sampler=sampler,
- height=height, thick=thick, depth=depth, nanmean=nanmean)
+ im, extents = make_flatmap_image(
+ dataview,
+ recache=recache,
+ pixelwise=pixelwise,
+ sampler=sampler,
+ height=height,
+ thick=thick,
+ depth=depth,
+ nanmean=nanmean,
+ )
# Check whether dataview has a cmap instance
cmapdict = dataview.get_cmapdict()
# Plot
_, ax = _get_fig_and_ax(fig)
- img = ax.imshow(im,
- aspect='equal',
- extent=extents,
- label='data',
- zorder=1,
- interpolation="nearest",
- **cmapdict)
+ img = ax.imshow(
+ im,
+ aspect="equal",
+ extent=extents,
+ label="data",
+ zorder=1,
+ interpolation="nearest",
+ **cmapdict,
+ )
return img, extents
-def add_rois(fig, dataview, extents=None, height=None, with_labels=True, roi_list=None, overlay_file=None, **kwargs):
+
+def add_rois(
+ fig,
+ dataview,
+ extents=None,
+ height=None,
+ with_labels=True,
+ roi_list=None,
+ overlay_file=None,
+ **kwargs,
+):
"""Add ROIs layer to a figure
NOTE: zorder for rois is 3
@@ -186,15 +248,15 @@ def add_rois(fig, dataview, extents=None, height=None, with_labels=True, roi_lis
dataview : cortex.Dataview object
dataview containing data to be plotted, subject (surface identifier), and transform.
extents : array-like
- 4 values for [Left, Right, Top, Bottom] extents of image plotted. None defaults to
+ 4 values for [Left, Right, Top, Bottom] extents of image plotted. None defaults to
extents of images already present in figure.
- height : scalar
- Height of image. None defaults to height of images already present in figure.
+ height : scalar
+ Height of image. None defaults to height of images already present in figure.
with_labels : bool
Whether to display text labels on ROIs
- roi_list :
+ roi_list :
- kwargs :
+ kwargs :
Returns
-------
@@ -204,23 +266,36 @@ def add_rois(fig, dataview, extents=None, height=None, with_labels=True, roi_lis
if extents is None:
extents = _get_extents(fig)
if height is None:
- height = _get_height(fig)
+ height = _get_height(fig)
svgobject = db.get_overlay(dataview.subject, overlay_file=overlay_file)
svg_kws = _convert_svg_kwargs(kwargs)
- layer_kws = _parse_defaults('rois_paths')
+ layer_kws = _parse_defaults("rois_paths")
layer_kws.update(svg_kws)
- im = svgobject.get_texture('rois', height, labels=with_labels, shape_list=roi_list, **layer_kws)
+ im = svgobject.get_texture(
+ "rois", height, labels=with_labels, shape_list=roi_list, **layer_kws
+ )
_, ax = _get_fig_and_ax(fig)
- img = ax.imshow(im,
- aspect='equal',
- interpolation='bicubic',
- extent=extents,
- label='rois',
- zorder=1000)
+ img = ax.imshow(
+ im,
+ aspect="equal",
+ interpolation="bicubic",
+ extent=extents,
+ label="rois",
+ zorder=1000,
+ )
return img
-def add_sulci(fig, dataview, extents=None, height=None, with_labels=True, sulci_list=None, overlay_file=None, **kwargs):
+def add_sulci(
+ fig,
+ dataview,
+ extents=None,
+ height=None,
+ with_labels=True,
+ sulci_list=None,
+ overlay_file=None,
+ **kwargs,
+):
"""Add sulci layer to figure
Parameters
@@ -230,10 +305,10 @@ def add_sulci(fig, dataview, extents=None, height=None, with_labels=True, sulci_
dataview : cortex.Dataview object
dataview containing data to be plotted, subject (surface identifier), and transform.
extents : array-like
- 4 values for [Left, Right, Top, Bottom] extents of image plotted. None defaults to
+ 4 values for [Left, Right, Top, Bottom] extents of image plotted. None defaults to
extents of images already present in figure.
height : scalar
- Height of image. None defaults to height of images already present in figure.
+ Height of image. None defaults to height of images already present in figure.
with_labels : bool
Whether to display text labels for sulci
sulci_list : list
@@ -252,23 +327,35 @@ def add_sulci(fig, dataview, extents=None, height=None, with_labels=True, sulci_
"""
svgobject = db.get_overlay(dataview.subject, overlay_file=overlay_file)
svg_kws = _convert_svg_kwargs(kwargs)
- layer_kws = _parse_defaults('sulci_paths')
+ layer_kws = _parse_defaults("sulci_paths")
layer_kws.update(svg_kws)
- sulc = svgobject.get_texture('sulci', height, labels=with_labels, shape_list=sulci_list, **layer_kws)
+ sulc = svgobject.get_texture(
+ "sulci", height, labels=with_labels, shape_list=sulci_list, **layer_kws
+ )
if extents is None:
extents = _get_extents(fig)
_, ax = _get_fig_and_ax(fig)
- img = ax.imshow(sulc,
- aspect='equal',
- interpolation='bicubic',
- extent=extents,
- label='sulci',
- zorder=5)
+ img = ax.imshow(
+ sulc,
+ aspect="equal",
+ interpolation="bicubic",
+ extent=extents,
+ label="sulci",
+ zorder=5,
+ )
return img
-def add_hatch(fig, hatch_data, extents=None, height=None, hatch_space=4,
- hatch_color=(0, 0, 0), sampler='nearest', recache=False):
+def add_hatch(
+ fig,
+ hatch_data,
+ extents=None,
+ height=None,
+ hatch_space=4,
+ hatch_color=(0, 0, 0),
+ sampler="nearest",
+ recache=False,
+):
"""Add hatching to figure at locations specified in hatch_data
Parameters
@@ -279,20 +366,20 @@ def add_hatch(fig, hatch_data, extents=None, height=None, hatch_space=4,
cortex.Volume object created from data scaled from 0-1; locations with values of 1 will
have hatching overlaid on them in the resulting image.
extents : array-like
- 4 values for [Left, Right, Top, Bottom] extents of image plotted. If None, defaults to
+ 4 values for [Left, Right, Top, Bottom] extents of image plotted. If None, defaults to
extents of images already present in figure.
- height : scalar
- Height of image. if None, defaults to height of images already present in figure.
- hatch_space : scalar
+ height : scalar
+ Height of image. if None, defaults to height of images already present in figure.
+ hatch_space : scalar
Spacing between hatch lines, in pixels
hatch_color : 3-tuple
(R, G, B) tuple for color of hatching. Values for R,G,B should be 0-1
sampler : str
- Name of sampling function used to sample underlying volume data. Options include
+ Name of sampling function used to sample underlying volume data. Options include
'trilinear','nearest','lanczos'; see functions in cortex.mapper.samplers.py for all options
recache : boolean
Whether or not to recache intermediate files. Takes longer to plot this way, potentially
- resolves some errors.
+ resolves some errors.
Returns
-------
@@ -307,24 +394,32 @@ def add_hatch(fig, hatch_data, extents=None, height=None, hatch_space=4,
extents = _get_extents(fig)
if height is None:
height = _get_height(fig)
- hatchim = _make_hatch_image(hatch_data, height, sampler, recache=recache,
- hatch_space=hatch_space)
- hatchim[:,:,0] = hatch_color[0]
- hatchim[:,:,1] = hatch_color[1]
- hatchim[:,:,2] = hatch_color[2]
+ hatchim = _make_hatch_image(
+ hatch_data, height, sampler, recache=recache, hatch_space=hatch_space
+ )
+ hatchim[:, :, 0] = hatch_color[0]
+ hatchim[:, :, 1] = hatch_color[1]
+ hatchim[:, :, 2] = hatch_color[2]
_, ax = _get_fig_and_ax(fig)
- img = ax.imshow(hatchim,
- aspect="equal",
- interpolation="bicubic",
- extent=extents,
- label='hatch',
- zorder=2)
+ img = ax.imshow(
+ hatchim,
+ aspect="equal",
+ interpolation="bicubic",
+ extent=extents,
+ label="hatch",
+ zorder=2,
+ )
return img
-def add_colorbar(fig, cimg, colorbar_ticks=None, colorbar_location=(0.4, 0.07, 0.2, 0.04),
- orientation='horizontal'):
+def add_colorbar(
+ fig,
+ cimg,
+ colorbar_ticks=None,
+ colorbar_location=(0.4, 0.07, 0.2, 0.04),
+ orientation="horizontal",
+):
"""Add a colorbar to a flatmap plot
Parameters
@@ -332,12 +427,12 @@ def add_colorbar(fig, cimg, colorbar_ticks=None, colorbar_location=(0.4, 0.07, 0
fig : matplotlib Figure object
Figure into which to insert colormap
cimg : matplotlib.image.AxesImage object
- Image for which to create colorbar. For reference, matplotlib.image.AxesImage
+ Image for which to create colorbar. For reference, matplotlib.image.AxesImage
is the output of imshow()
colorbar_ticks : array-like
values for colorbar ticks
colorbar_location : array-like
- Four-long list, tuple, or array that specifies location for colorbar axes
+ Four-long list, tuple, or array that specifies location for colorbar axes
[left, top, width, height] (?)
orientation : string
'vertical' or 'horizontal'
@@ -348,20 +443,25 @@ def add_colorbar(fig, cimg, colorbar_ticks=None, colorbar_location=(0.4, 0.07, 0
return cbar
-def add_colorbar_2d(fig, cmap_name, colorbar_ticks,
- colorbar_location=(0.425, 0.02, 0.15, 0.15), fontsize=12):
+def add_colorbar_2d(
+ fig,
+ cmap_name,
+ colorbar_ticks,
+ colorbar_location=(0.425, 0.02, 0.15, 0.15),
+ fontsize=12,
+):
"""Add a 2D colorbar to a flatmap plot
Parameters
----------
fig : matplotlib Figure object
cimg : matplotlib.image.AxesImage object
- Image for which to create colorbar. For reference, matplotlib.image.AxesImage
+ Image for which to create colorbar. For reference, matplotlib.image.AxesImage
is the output of imshow()
colorbar_ticks : array-like
values for colorbar ticks
colorbar_location : array-like
- Four-long list, tuple, or array that specifies location for colorbar axes
+ Four-long list, tuple, or array that specifies location for colorbar axes
[left, top, width, height] (?)
orientation : string
'vertical' or 'horizontal'
@@ -369,11 +469,12 @@ def add_colorbar_2d(fig, cmap_name, colorbar_ticks,
# a bit sketchy - lazy imports
import matplotlib.pyplot as plt
import os
- cmap_dir = config.get('webgl', 'colormaps')
- cim = plt.imread(os.path.join(cmap_dir, cmap_name + '.png'))
+
+ cmap_dir = config.get("webgl", "colormaps")
+ cim = plt.imread(os.path.join(cmap_dir, cmap_name + ".png"))
fig, _ = _get_fig_and_ax(fig)
fig.add_axes(colorbar_location)
- cbar = plt.imshow(cim, extent=colorbar_ticks, interpolation='bilinear')
+ cbar = plt.imshow(cim, extent=colorbar_ticks, interpolation="bilinear")
cbar.axes.set_xticks(colorbar_ticks[:2])
cbar.axes.set_xticklabels(colorbar_ticks[:2], fontdict=dict(size=fontsize))
cbar.axes.set_yticks(colorbar_ticks[2:])
@@ -381,8 +482,18 @@ def add_colorbar_2d(fig, cmap_name, colorbar_ticks,
return cbar
-def add_custom(fig, dataview, svgfile, layer, extents=None, height=None, with_labels=False,
- shape_list=None, **kwargs):
+
+def add_custom(
+ fig,
+ dataview,
+ svgfile,
+ layer,
+ extents=None,
+ height=None,
+ with_labels=False,
+ shape_list=None,
+ **kwargs,
+):
"""Add a custom data layer
Parameters
@@ -397,10 +508,10 @@ def add_custom(fig, dataview, svgfile, layer, extents=None, height=None, with_la
layer : string
Layer name within custom svg file to display
extents : array-like
- 4 values for [Left, Right, Bottom, Top] extents of image plotted. If None, defaults to
+ 4 values for [Left, Right, Bottom, Top] extents of image plotted. If None, defaults to
extents of images already present in figure.
height : scalar
- Height of image. if None, defaults to height of images already present in figure.
+ Height of image. if None, defaults to height of images already present in figure.
with_labels : bool
Whether to display text labels on ROIs
shape_list : list
@@ -419,6 +530,7 @@ def add_custom(fig, dataview, svgfile, layer, extents=None, height=None, with_la
"""
from ..svgoverlay import get_overlay
+
if height is None:
height = _get_height(fig)
if extents is None:
@@ -428,27 +540,37 @@ def add_custom(fig, dataview, svgfile, layer, extents=None, height=None, with_la
svg_kws = _convert_svg_kwargs(kwargs)
try:
# Check for layer if it exists
- layer_kws = _parse_defaults(layer+'_paths')
+ layer_kws = _parse_defaults(layer + "_paths")
layer_kws.update(svg_kws)
except:
layer_kws = svg_kws
- im = extra_svg.get_texture(layer, height,
- labels=with_labels,
- shape_list=shape_list,
- **layer_kws)
+ im = extra_svg.get_texture(
+ layer, height, labels=with_labels, shape_list=shape_list, **layer_kws
+ )
_, ax = _get_fig_and_ax(fig)
- img = ax.imshow(im,
- aspect="equal",
- interpolation="nearest",
- extent=extents,
- label='custom',
- zorder=6)
+ img = ax.imshow(
+ im,
+ aspect="equal",
+ interpolation="nearest",
+ extent=extents,
+ label="custom",
+ zorder=6,
+ )
return img
-def add_connected_vertices(fig, dataview, exclude_border_width=None,
- height=None, extents=None, recache=False,
- color=(1.0, 0.5, 0.1, 0.6), linewidth=0.75,
- alpha=1.0, **kwargs):
+
+def add_connected_vertices(
+ fig,
+ dataview,
+ exclude_border_width=None,
+ height=None,
+ extents=None,
+ recache=False,
+ color=(1.0, 0.5, 0.1, 0.6),
+ linewidth=0.75,
+ alpha=1.0,
+ **kwargs,
+):
"""Plot lines btw distant vertices that are within the same voxel
Parameters
@@ -492,11 +614,13 @@ def add_connected_vertices(fig, dataview, exclude_border_width=None,
if extents is None:
extents = _get_extents(fig)
if height is None:
- height = _get_height(fig)
+ height = _get_height(fig)
subject = dataview.subject
xfmname = dataview.xfmname
if xfmname is None:
- raise ValueError("Dataview for add_connected_vertices must be a Volume! You seem to have provided vertex data.")
+ raise ValueError(
+ "Dataview for add_connected_vertices must be a Volume! You seem to have provided vertex data."
+ )
# print('computing shared voxels')
shared_voxels = db.get_shared_voxels(subject, xfmname, recache=recache, **kwargs)
# print('Finished computing shared voxels')
@@ -506,14 +630,20 @@ def add_connected_vertices(fig, dataview, exclude_border_width=None,
if exclude_border_width:
# Finding vertices that map to the border of the flatmap
- img = np.nan * np.ones(mask.shape)
- img[mask] = pixmap * np.arange(n_verts) # mapper.nverts
+ img = np.nan * np.ones(mask.shape)
+ img[mask] = pixmap * np.arange(n_verts) # mapper.nverts
border_mask = binary_dilation(~mask, iterations=exclude_border_width) ^ (~mask)
border_vertices = set(img[border_mask].astype(int))
- shared_voxels = np.array([a for a in shared_voxels if ((a[1] not in border_vertices) and (a[2] not in border_vertices))])
+ shared_voxels = np.array(
+ [
+ a
+ for a in shared_voxels
+ if ((a[1] not in border_vertices) and (a[2] not in border_vertices))
+ ]
+ )
valid_vert_mask = np.array(pixmap.sum(0) > 0).flatten()
- valid_verts = np.arange(n_verts)[valid_vert_mask] # mapper.nverts
+ valid_verts = np.arange(n_verts)[valid_vert_mask] # mapper.nverts
# Assure both vertices in each pair are not in the medial wall
vtx1valid = np.isin(shared_voxels[:, 1], valid_verts)
vtx2valid = np.isin(shared_voxels[:, 2], valid_verts)
@@ -532,16 +662,21 @@ def add_connected_vertices(fig, dataview, exclude_border_width=None,
# (This is the most time consuming step, as it draws many lines)
# print('plotting lines...')
fig, ax = _get_fig_and_ax(fig)
- lc = LineCollection(pix_array_scaled,
- transform=fig.transFigure,
- figure=fig,
- colors=color,
- alpha=alpha,
- linewidths=linewidth)
+ lc = LineCollection(
+ pix_array_scaled,
+ transform=fig.transFigure,
+ figure=fig,
+ colors=color,
+ alpha=alpha,
+ linewidths=linewidth,
+ )
lc_object = ax.add_collection(lc)
return lc_object
-def add_cutout(fig, name, dataview, layers=None, height=None, extents=None, overlay_file=None):
+
+def add_cutout(
+ fig, name, dataview, layers=None, height=None, extents=None, overlay_file=None
+):
"""Apply a cutout mask to extant layers in flatmap figure
Parameters
@@ -574,13 +709,12 @@ def add_cutout(fig, name, dataview, layers=None, height=None, extents=None, over
for co_name, co_shape in svgobject.cutouts.shapes.items():
co_shape.visible = co_name == name
# Get cutout image (now all white = 1, black = 0)
- svg_kws = _convert_svg_kwargs(dict(fillcolor="white",
- fillalpha=1.0,
- linecolor="white",
- linewidth=2))
- co = svgobject.get_texture('cutouts', height, labels=False, **svg_kws)[..., 0]
+ svg_kws = _convert_svg_kwargs(
+ dict(fillcolor="white", fillalpha=1.0, linecolor="white", linewidth=2)
+ )
+ co = svgobject.get_texture("cutouts", height, labels=False, **svg_kws)[..., 0]
if not np.any(co):
- raise Exception('No pixels in cutout region {}!'.format(name))
+ raise Exception("No pixels in cutout region {}!".format(name))
# Bounding box indices
LL, RR, BB, TT = np.nan, np.nan, np.nan, np.nan
@@ -588,41 +722,49 @@ def add_cutout(fig, name, dataview, layers=None, height=None, extents=None, over
for layer_name, im_layer in layers.items():
im = im_layer.get_array()
- # Reconcile occasional 1-pixel difference between flatmap image layers
+ # Reconcile occasional 1-pixel difference between flatmap image layers
# that are generated by different functions
if not all([np.abs(aa - bb) <= 1 for aa, bb in zip(im.shape, co.shape)]):
raise Exception("Shape mismatch btw cutout and data!")
- if any([np.abs(aa - bb) > 0 and np.abs(aa - bb) < 2 for aa, bb in zip(im.shape, co.shape)]):
+ if any(
+ [
+ np.abs(aa - bb) > 0 and np.abs(aa - bb) < 2
+ for aa, bb in zip(im.shape, co.shape)
+ ]
+ ):
from scipy.misc import imresize
- print('Resizing! {} to {}'.format(co.shape, im.shape[:2]))
- layer_cutout = imresize(co, im.shape[:2]).astype(np.float32)/255.
+
+ print("Resizing! {} to {}".format(co.shape, im.shape[:2]))
+ layer_cutout = imresize(co, im.shape[:2]).astype(np.float32) / 255.0
else:
layer_cutout = copy.copy(co)
# Handle different types of alpha layers. Useful for RGBVolumes if nothing else.
if im.dtype == np.uint8:
- im = np.cast['float32'](im)/255.
- im[:,:,3] *= layer_cutout
+ im = np.cast["float32"](im) / 255.0
+ im[:, :, 3] *= layer_cutout
h, w, cdim = [float(v) for v in im.shape]
else:
- if np.ndim(im)==3:
- im[:,:,3] *= layer_cutout
+ if np.ndim(im) == 3:
+ im[:, :, 3] *= layer_cutout
h, w, cdim = [float(v) for v in im.shape]
- elif np.ndim(im)==2:
- im[layer_cutout==0] = np.nan
+ elif np.ndim(im) == 2:
+ im[layer_cutout == 0] = np.nan
h, w = [float(v) for v in im.shape]
y, x = np.nonzero(layer_cutout)
l, r, b, t = extents
- x_span = np.abs(r-l)
- y_span = np.abs(t-b)
- extents_new = [l + x.min() / w * x_span,
- l + x.max() / w * x_span,
- t + y.min() / h * y_span,
- t + y.max() / h * y_span]
+ x_span = np.abs(r - l)
+ y_span = np.abs(t - b)
+ extents_new = [
+ l + x.min() / w * x_span,
+ l + x.max() / w * x_span,
+ t + y.min() / h * y_span,
+ t + y.max() / h * y_span,
+ ]
# Bounding box indices
iy, ix = ((y.min(), y.max()), (x.min(), x.max()))
- tmp = im[iy[0]:iy[1], ix[0]:ix[1]]
+ tmp = im[iy[0] : iy[1], ix[0] : ix[1]]
im_layer.set_array(tmp)
im_layer.set_extent(extents_new)
@@ -641,3 +783,127 @@ def add_cutout(fig, name, dataview, layers=None, height=None, extents=None, over
fig.set_size_inches(inch_size[0], inch_size[1])
return
+
+
+def _detect_label_borders(label_img):
+ """Detect border pixels in a 2D label image.
+
+ A pixel is a border if any of its 4-connected neighbors has a different
+ value. NaN pixels (outside brain mask) are not borders.
+
+ Parameters
+ ----------
+ label_img : ndarray, shape (H, W) or (H, W, C)
+ Label image. If 3D (e.g. RGBA), the first channel is used.
+
+ Returns
+ -------
+ border : ndarray, shape (H, W), dtype bool
+ True at border pixels.
+ """
+ if label_img.ndim == 3:
+ vals = label_img[:, :, 0]
+ else:
+ vals = label_img
+
+ valid = ~np.isnan(vals)
+ border = np.zeros(vals.shape, dtype=bool)
+ for di, dj in [(-1, 0), (1, 0), (0, -1), (0, 1)]:
+ shifted = np.roll(np.roll(vals, di, axis=0), dj, axis=1)
+ shifted_valid = ~np.isnan(shifted)
+ # Invalidate wrapped edges to avoid false borders from np.roll
+ if di == -1:
+ shifted_valid[-1, :] = False
+ elif di == 1:
+ shifted_valid[0, :] = False
+ if dj == -1:
+ shifted_valid[:, -1] = False
+ elif dj == 1:
+ shifted_valid[:, 0] = False
+ border |= (vals != shifted) & valid & shifted_valid
+ return border
+
+
+def add_contours(
+ fig,
+ contour_data,
+ extents=None,
+ height=None,
+ linewidth=1,
+ linecolor=(0, 0, 0, 1),
+ sampler="nearest",
+ recache=False,
+):
+ """Add contour borders of parcellation data to a quickflat plot.
+
+ Parameters
+ ----------
+ fig : matplotlib figure or axes
+ Figure or axes to plot into.
+ contour_data : cortex.Dataview
+ Parcellation/label data whose borders will be drawn.
+ extents : array-like, optional
+ [Left, Right, Top, Bottom] extents. None uses extents from existing images.
+ height : int, optional
+ Height of the flatmap image. None uses height of existing images.
+ linewidth : int
+ Width of contour lines in pixels (default: 1).
+ linecolor : tuple of float
+ (R, G, B, A) color for contour lines (default: black).
+ sampler : str
+ Sampling method (default: 'nearest' to preserve label boundaries).
+ recache : bool
+ Whether to recache intermediate files.
+
+ Returns
+ -------
+ img : matplotlib.image.AxesImage
+ Axes image object for plotted contour overlay.
+ """
+ dataview = dataset.normalize(contour_data)
+ if not isinstance(dataview, dataset.Dataview):
+ raise TypeError("Please provide a Dataview (e.g. cortex.Vertex), not a Dataset")
+
+ try:
+ if extents is None:
+ extents = _get_extents(fig)
+ except ValueError:
+ extents = None
+ if height is None:
+ try:
+ height = _get_height(fig)
+ except (ValueError, IndexError):
+ height = 1024
+
+ # Generate flatmap image of the label data
+ label_img, extents_out = make_flatmap_image(
+ dataview, height=height, recache=recache, sampler=sampler
+ )
+
+ if extents is None:
+ extents = extents_out
+
+ # Detect borders
+ border = _detect_label_borders(label_img)
+
+ # Optionally dilate for thicker lines
+ if linewidth > 1:
+ from scipy.ndimage import binary_dilation
+
+ struct = np.ones((linewidth, linewidth))
+ border = binary_dilation(border, structure=struct)
+
+ # Create RGBA overlay image
+ rgba = np.zeros(label_img.shape[:2] + (4,), dtype=np.float32)
+ rgba[border] = linecolor
+
+ _, ax = _get_fig_and_ax(fig)
+ img = ax.imshow(
+ rgba,
+ aspect="equal",
+ extent=extents,
+ interpolation="nearest",
+ zorder=4,
+ label="contours",
+ )
+ return img
diff --git a/cortex/quickflat/view.py b/cortex/quickflat/view.py
index 1d7102905..0b6cdf7ea 100644
--- a/cortex/quickflat/view.py
+++ b/cortex/quickflat/view.py
@@ -4,7 +4,7 @@
import binascii
import numpy as np
import numpy.typing as npt
-from typing import Optional, Union, IO
+from typing import Literal, Optional, Union, IO
from matplotlib.axes import Axes
from matplotlib.figure import Figure
@@ -16,33 +16,68 @@
default_colorbar_locations = {
- 'left': (.0, .07, .2, .04),
- 'center': (.4, .07, .2, .04),
- 'right': (.7, .07, .2, .04)
+ "left": (0.0, 0.07, 0.2, 0.04),
+ "center": (0.4, 0.07, 0.2, 0.04),
+ "right": (0.7, 0.07, 0.2, 0.04),
}
-def _check_colorbar_location(colorbar_location: Union[tuple[float, float, float, float], str]) -> tuple[float, float, float, float]:
+def _check_colorbar_location(
+ colorbar_location: Union[tuple[float, float, float, float], str],
+) -> tuple[float, float, float, float]:
if isinstance(colorbar_location, (tuple, list)):
return colorbar_location
if colorbar_location not in default_colorbar_locations:
- raise ValueError("colorbar_location must be one of {}".format(
- list(default_colorbar_locations.keys())))
+ raise ValueError(
+ "colorbar_location must be one of {}".format(
+ list(default_colorbar_locations.keys())
+ )
+ )
return default_colorbar_locations[colorbar_location]
-def make_figure(braindata: dataset.Dataview, recache: bool=False, pixelwise: bool=True, thick: int=32, sampler: str='nearest',
- height: int=1024, dpi: int=100, depth: float=0.5, with_rois: bool=True, with_sulci: bool=False,
- with_labels: bool=True, with_colorbar: bool=True, with_borders: bool=False,
- with_dropout: Union[bool, float]=False, with_curvature: bool=False, extra_disp: Optional[tuple[str, str]]=None,
- with_connected_vertices: bool=False, overlay_file: Optional[str]=None,
- linewidth: Optional[int]=None, linecolor: Optional[ColorType]=None, roifill: Optional[ColorType]=None, shadow: Optional[int]=None,
- labelsize: Optional[str]=None, labelcolor: Optional[ColorType]=None, cutout: Optional[str]=None, curvature_brightness: Optional[float]=None,
- curvature_contrast: Optional[float]=None, curvature_threshold: Optional[bool]=None, fig: Optional[Union[Figure, Axes]]=None, extra_hatch: Optional[tuple[dataset.Dataview, tuple[float, float, float]]]=None,
- colorbar_ticks: Optional[npt.ArrayLike]=None, colorbar_location: Union[tuple[float, float, float, float], str]='center', roi_list: Optional[list[str]]=None, sulci_list: Optional[list[str]]=None,
- nanmean: bool=False) -> Figure:
+def make_figure(
+ braindata: dataset.Dataview,
+ recache: bool = False,
+ pixelwise: bool = True,
+ thick: int = 32,
+ sampler: str = "nearest",
+ height: int = 1024,
+ dpi: int = 100,
+ depth: float = 0.5,
+ with_rois: bool = True,
+ with_sulci: bool = False,
+ with_labels: bool = True,
+ with_colorbar: bool = True,
+ with_borders: bool = False,
+ with_dropout: Union[bool, float] = False,
+ with_curvature: bool = False,
+ extra_disp: Optional[tuple[str, str]] = None,
+ with_connected_vertices: bool = False,
+ overlay_file: Optional[str] = None,
+ linewidth: Optional[int] = None,
+ linecolor: Optional[ColorType] = None,
+ roifill: Optional[ColorType] = None,
+ shadow: Optional[int] = None,
+ labelsize: Optional[str] = None,
+ labelcolor: Optional[ColorType] = None,
+ cutout: Optional[str] = None,
+ curvature_brightness: Optional[float] = None,
+ curvature_contrast: Optional[float] = None,
+ curvature_threshold: Optional[bool] = None,
+ fig: Optional[Union[Figure, Axes]] = None,
+ extra_hatch: Optional[tuple[dataset.Dataview, tuple[float, float, float]]] = None,
+ colorbar_ticks: Optional[npt.ArrayLike] = None,
+ colorbar_location: Union[tuple[float, float, float, float], str] = "center",
+ roi_list: Optional[list[str]] = None,
+ sulci_list: Optional[list[str]] = None,
+ nanmean: bool = False,
+ with_contours: Union[Literal[False], dataset.Dataview] = False,
+ contour_linewidth: Optional[int] = None,
+ contour_linecolor: Optional[ColorType] = None,
+) -> Figure:
"""Show a Volume or Vertex on a flatmap with matplotlib.
Parameters
@@ -123,12 +158,22 @@ def make_figure(braindata: dataset.Dataview, recache: bool=False, pixelwise: boo
figure into which to plot flatmap
nanmean : bool, optional (default = False)
If True, NaNs in the data will be ignored when averaging across layers.
+ with_contours : cortex.Dataview or False, optional
+ Parcellation data whose label boundaries will be drawn as contour
+ lines on top of the plotted data. Pass a Vertex (or other Dataview)
+ with discrete labels. False (default) disables contours.
+ contour_linewidth : int, optional
+ Width of contour lines in pixels. None defaults to 1.
+ contour_linecolor : tuple of float, optional
+ (R, G, B, A) color for contour lines. None defaults to black.
"""
from matplotlib import pyplot as plt
dataview = dataset.normalize(braindata)
if not isinstance(dataview, dataset.Dataview):
- raise TypeError('Please provide a Dataview (e.g. an instance of cortex.Volume, cortex.Vertex, etc), not a Dataset')
+ raise TypeError(
+ "Please provide a Dataview (e.g. an instance of cortex.Volume, cortex.Vertex, etc), not a Dataset"
+ )
if fig is None:
fig_resize = True
fig = plt.figure()
@@ -143,20 +188,33 @@ def make_figure(braindata: dataset.Dataview, recache: bool=False, pixelwise: boo
fig = ax.figure
# Add data
- data_im, extents = composite.add_data(ax, dataview, pixelwise=pixelwise, thick=thick, sampler=sampler,
- height=height, depth=depth, recache=recache, nanmean=nanmean)
+ data_im, extents = composite.add_data(
+ ax,
+ dataview,
+ pixelwise=pixelwise,
+ thick=thick,
+ sampler=sampler,
+ height=height,
+ depth=depth,
+ recache=recache,
+ nanmean=nanmean,
+ )
layers = dict(data=data_im)
# Add curvature
if with_curvature:
- curv_im = composite.add_curvature(ax, dataview, extents,
- brightness=curvature_brightness,
- contrast=curvature_contrast,
- threshold=curvature_threshold,
- curvature_lims=0.5,
- legacy_mode=False,
- recache=recache)
- layers['curvature'] = curv_im
+ curv_im = composite.add_curvature(
+ ax,
+ dataview,
+ extents,
+ brightness=curvature_brightness,
+ contrast=curvature_contrast,
+ threshold=curvature_threshold,
+ curvature_lims=0.5,
+ legacy_mode=False,
+ recache=recache,
+ )
+ layers["curvature"] = curv_im
# Add dropout
if with_dropout is not False:
@@ -167,43 +225,105 @@ def make_figure(braindata: dataset.Dataview, recache: bool=False, pixelwise: boo
hatch_data = None
dropout_power = 20 if with_dropout is True else with_dropout
if hatch_data is None:
- hatch_data = utils.get_dropout(dataview.subject, dataview.xfmname,
- power=dropout_power)
+ hatch_data = utils.get_dropout(
+ dataview.subject, dataview.xfmname, power=dropout_power
+ )
- drop_im = composite.add_hatch(ax, hatch_data, extents=extents, height=height,
- sampler=sampler, recache=recache)
- layers['dropout'] = drop_im
+ drop_im = composite.add_hatch(
+ ax,
+ hatch_data,
+ extents=extents,
+ height=height,
+ sampler=sampler,
+ recache=recache,
+ )
+ layers["dropout"] = drop_im
# Add extra hatching
if extra_hatch is not None:
hatch_data2, hatch_color = extra_hatch
- hatch_im = composite.add_hatch(ax, hatch_data2, extents=extents, height=height,
- sampler=sampler, recache=recache)
- layers['hatch'] = hatch_im
+ hatch_im = composite.add_hatch(
+ ax,
+ hatch_data2,
+ extents=extents,
+ height=height,
+ sampler=sampler,
+ recache=recache,
+ )
+ layers["hatch"] = hatch_im
# Add rois
if with_rois:
- roi_im = composite.add_rois(ax, dataview, extents=extents, height=height, linewidth=linewidth, linecolor=linecolor,
- roifill=roifill, shadow=shadow, labelsize=labelsize, labelcolor=labelcolor,
- with_labels=with_labels, overlay_file=overlay_file,
- roi_list=roi_list)
- layers['rois'] = roi_im
+ roi_im = composite.add_rois(
+ ax,
+ dataview,
+ extents=extents,
+ height=height,
+ linewidth=linewidth,
+ linecolor=linecolor,
+ roifill=roifill,
+ shadow=shadow,
+ labelsize=labelsize,
+ labelcolor=labelcolor,
+ with_labels=with_labels,
+ overlay_file=overlay_file,
+ roi_list=roi_list,
+ )
+ layers["rois"] = roi_im
# Add sulci
if with_sulci:
- sulc_im = composite.add_sulci(ax, dataview, extents=extents, height=height, linewidth=linewidth, linecolor=linecolor,
- shadow=shadow, labelsize=labelsize, labelcolor=labelcolor, with_labels=with_labels,
- overlay_file=overlay_file, sulci_list=sulci_list)
- layers['sulci'] = sulc_im
+ sulc_im = composite.add_sulci(
+ ax,
+ dataview,
+ extents=extents,
+ height=height,
+ linewidth=linewidth,
+ linecolor=linecolor,
+ shadow=shadow,
+ labelsize=labelsize,
+ labelcolor=labelcolor,
+ with_labels=with_labels,
+ overlay_file=overlay_file,
+ sulci_list=sulci_list,
+ )
+ layers["sulci"] = sulc_im
# Add custom
if extra_disp is not None:
svgfile, layer = extra_disp
- custom_im = composite.add_custom(ax, dataview, svgfile, layer, height=height, extents=extents,
- linewidth=linewidth, linecolor=linecolor, shadow=shadow, labelsize=labelsize,
- labelcolor=labelcolor, with_labels=with_labels)
- layers['custom'] = custom_im
+ custom_im = composite.add_custom(
+ ax,
+ dataview,
+ svgfile,
+ layer,
+ height=height,
+ extents=extents,
+ linewidth=linewidth,
+ linecolor=linecolor,
+ shadow=shadow,
+ labelsize=labelsize,
+ labelcolor=labelcolor,
+ with_labels=with_labels,
+ )
+ layers["custom"] = custom_im
+ # Add contours
+ if with_contours is not False:
+ contour_kw = {}
+ if contour_linewidth is not None:
+ contour_kw["linewidth"] = contour_linewidth
+ if contour_linecolor is not None:
+ contour_kw["linecolor"] = contour_linecolor
+ contour_im = composite.add_contours(
+ ax,
+ with_contours,
+ extents=extents,
+ height=height,
+ recache=recache,
+ **contour_kw,
+ )
+ layers["contours"] = contour_im
# Add connector lines btw connected vertices
if with_connected_vertices:
vertex_lines = composite.add_connected_vertices(ax, dataview, recache=recache)
- ax.axis('off')
+ ax.axis("off")
ax.set_xlim(extents[0], extents[1])
ax.set_ylim(extents[2], extents[3])
@@ -213,32 +333,44 @@ def make_figure(braindata: dataset.Dataview, recache: bool=False, pixelwise: boo
# Add (apply) cutout of flatmap
if cutout is not None:
- extents = composite.add_cutout(ax, cutout, dataview, layers, overlay_file=overlay_file)
+ extents = composite.add_cutout(
+ ax, cutout, dataview, layers, overlay_file=overlay_file
+ )
if with_colorbar:
colorbar_location = _check_colorbar_location(colorbar_location)
# Allow 2D colorbars:
if isinstance(dataview, dataset.Dataview2D):
- colorbar_ticks = np.round([
- dataview.vmin, dataview.vmax,
- dataview.vmin2, dataview.vmax2
- ], 2)
+ colorbar_ticks = np.round(
+ [dataview.vmin, dataview.vmax, dataview.vmin2, dataview.vmax2], 2
+ )
colorbar = composite.add_colorbar_2d(
- ax, dataview.cmap, colorbar_ticks,
- colorbar_location=colorbar_location)
+ ax, dataview.cmap, colorbar_ticks, colorbar_location=colorbar_location
+ )
else:
colorbar = composite.add_colorbar(
- ax, data_im,
+ ax,
+ data_im,
colorbar_location=colorbar_location,
- colorbar_ticks=colorbar_ticks
+ colorbar_ticks=colorbar_ticks,
)
# Reset axis to main figure axis
plt.sca(ax)
return fig
-def make_png(fname: Union[str, os.PathLike, IO], braindata: dataset.Dataview, recache: bool=False, pixelwise: bool=True, sampler: str='nearest', height: int=1024,
- bgcolor: Optional[ColorType]=None, dpi: int=100, **kwargs) -> None:
+
+def make_png(
+ fname: Union[str, os.PathLike, IO],
+ braindata: dataset.Dataview,
+ recache: bool = False,
+ pixelwise: bool = True,
+ sampler: str = "nearest",
+ height: int = 1024,
+ bgcolor: Optional[ColorType] = None,
+ dpi: int = 100,
+ **kwargs,
+) -> None:
"""Create a PNG of the VertexData or VolumeData on a flatmap.
Parameters
@@ -289,12 +421,15 @@ def make_png(fname: Union[str, os.PathLike, IO], braindata: dataset.Dataview, re
(R, G, B, A) specification for the label color
"""
from matplotlib import pyplot as plt
- fig = make_figure(braindata,
- recache=recache,
- pixelwise=pixelwise,
- sampler=sampler,
- height=height,
- **kwargs)
+
+ fig = make_figure(
+ braindata,
+ recache=recache,
+ pixelwise=pixelwise,
+ sampler=sampler,
+ height=height,
+ **kwargs,
+ )
imsize = fig.get_axes()[0].get_images()[0].get_size()
fig.set_size_inches(np.array(imsize)[::-1] / float(dpi))
@@ -305,8 +440,18 @@ def make_png(fname: Union[str, os.PathLike, IO], braindata: dataset.Dataview, re
fig.clf()
plt.close(fig)
-def make_svg(fname, braindata, with_labels=False, with_curvature=True, layers=['rois'],
- height=1024, overlay_file=None, with_dropout=False, **kwargs):
+
+def make_svg(
+ fname,
+ braindata,
+ with_labels=False,
+ with_curvature=True,
+ layers=["rois"],
+ height=1024,
+ overlay_file=None,
+ with_dropout=False,
+ **kwargs,
+):
"""Save an svg file of the desired flatmap.
This function creates an SVG file with vector graphic ROIs overlaid on a single png image.
@@ -344,9 +489,9 @@ def make_svg(fname, braindata, with_labels=False, with_curvature=True, layers=['
arr, extents = make_flatmap_image(braindata, height=height, **kwargs)
# Set nans to alpha = 0. to enable transparency when saving as PNG
mask_nans = np.isnan(arr[..., 3])
- arr[mask_nans, 3] = 0.
+ arr[mask_nans, 3] = 0.0
- if hasattr(braindata, 'cmap'):
+ if hasattr(braindata, "cmap"):
imsave(fp, arr, cmap=braindata.cmap, vmin=braindata.vmin, vmax=braindata.vmax)
else:
imsave(fp, arr)
@@ -357,6 +502,7 @@ def make_svg(fname, braindata, with_labels=False, with_curvature=True, layers=['
if with_curvature:
# no options. learn to love it.
from cortex import db
+
fpc = io.BytesIO()
curv_vertices = db.get_surfinfo(braindata.subject)
curv_arr, _ = make_flatmap_image(curv_vertices, height=height)
@@ -364,7 +510,7 @@ def make_svg(fname, braindata, with_labels=False, with_curvature=True, layers=['
curv_arr = np.where(curv_arr > 0, 0.5, 0.25)
curv_arr[mask] = np.nan
- imsave(fpc, curv_arr, cmap='Greys_r', vmin=0, vmax=1)
+ imsave(fpc, curv_arr, cmap="Greys_r", vmin=0, vmax=1)
fpc.seek(0)
image_data = [binascii.b2a_base64(fpc.read()), pngdata]
@@ -425,16 +571,16 @@ def make_gif(output_destination, volumes, frame_duration=1, **figure_kwargs):
fig = plt.figure(figsize=(12, 6), dpi=100)
_ = make_figure(volumes[name], fig=fig, **figure_kwargs)
_ = fig.suptitle(name)
- path = os.path.join(tmpdir.name, str(i) + '.png')
+ path = os.path.join(tmpdir.name, str(i) + ".png")
fig.savefig(path)
images.append(imageio.imread(path))
_ = plt.close(fig)
tmpdir.cleanup()
- imageio.mimsave(output_destination, images, format='gif', duration=frame_duration)
+ imageio.mimsave(output_destination, images, format="gif", duration=frame_duration)
- if hasattr(output_destination, 'seek'):
+ if hasattr(output_destination, "seek"):
output_destination.seek(0)
@@ -442,9 +588,25 @@ def show(*args, **kwargs):
"""Wrapper for make_figure()"""
return make_figure(*args, **kwargs)
-def make_movie(name, data, subject, xfmname, recache=False, height=1024,
- sampler='nearest', dpi=100, tr=2, interp='linear', fps=30,
- vcodec='libtheora', bitrate="8000k", vmin=None, vmax=None, **kwargs):
+
+def make_movie(
+ name,
+ data,
+ subject,
+ xfmname,
+ recache=False,
+ height=1024,
+ sampler="nearest",
+ dpi=100,
+ tr=2,
+ interp="linear",
+ fps=30,
+ vcodec="libtheora",
+ bitrate="8000k",
+ vmin=None,
+ vmax=None,
+ **kwargs,
+):
"""Create a movie of an 4D data set"""
raise NotImplementedError
import sys
@@ -457,7 +619,9 @@ def make_movie(name, data, subject, xfmname, recache=False, height=1024,
from scipy.interpolate import interp1d
# Make the flatmaps
- ims, extents = make_flatmap_image(data, subject, xfmname, recache=recache, height=height, sampler=sampler)
+ ims, extents = make_flatmap_image(
+ data, subject, xfmname, recache=recache, height=height, sampler=sampler
+ )
if vmin is None:
vmin = np.nanmin(ims)
if vmax is None:
@@ -469,9 +633,9 @@ def make_movie(name, data, subject, xfmname, recache=False, height=1024,
img = fig.axes[0].images[0]
# Set up interpolation
- times = np.arange(0, len(ims)*tr, tr)
+ times = np.arange(0, len(ims) * tr, tr)
interp = interp1d(times, ims, kind=interp, axis=0, copy=False)
- frames = np.linspace(0, times[-1], (len(times)-1)*tr*fps+1)
+ frames = np.linspace(0, times[-1], (len(times) - 1) * tr * fps + 1)
try:
path = tempfile.mkdtemp()
@@ -481,7 +645,9 @@ def make_movie(name, data, subject, xfmname, recache=False, height=1024,
fig.savefig(impath.format(frame), transparent=True, dpi=dpi)
# avconv might not be relevant function for all operating systems.
# Introduce operating system check here?
- cmd = "avconv -i {path} -vcodec {vcodec} -r {fps} -b {br} {name}".format(path=impath, vcodec=vcodec, fps=fps, br=bitrate, name=name)
+ cmd = "avconv -i {path} -vcodec {vcodec} -r {fps} -b {br} {name}".format(
+ path=impath, vcodec=vcodec, fps=fps, br=bitrate, name=name
+ )
sp.call(shlex.split(cmd))
finally:
shutil.rmtree(path)
diff --git a/cortex/tests/test_contours.py b/cortex/tests/test_contours.py
new file mode 100644
index 000000000..80bcd4ddb
--- /dev/null
+++ b/cortex/tests/test_contours.py
@@ -0,0 +1,301 @@
+"""Tests for contour/border rendering of parcellation data.
+
+Tests cover:
+- Python utility: get_contour_vertices()
+- Quickflat: add_contours(), _detect_label_borders(), make_figure(with_contours=...)
+- WebGL: shader contour uniforms and geometry attributes
+"""
+
+import numpy as np
+import pytest
+
+import cortex
+from cortex.quickflat.composite import _detect_label_borders
+from cortex.testing_utils import has_installed
+
+no_inkscape = not has_installed("inkscape")
+
+SUBJECT = "S1"
+
+
+def _make_parcellation(subject=SUBJECT):
+ """Create parcellation vertex data from existing ROIs.
+
+ Returns array of shape (n_vertices,) with integer labels per ROI
+ and 0 for vertices not in any ROI.
+ """
+ import warnings
+
+ with warnings.catch_warnings():
+ warnings.simplefilter("ignore")
+ roi_verts = cortex.get_roi_verts(subject)
+ n_verts = cortex.db.get_surf(subject, "fiducial", merge=True)[0].shape[0]
+ parcellation = np.zeros(n_verts, dtype=float)
+ for i, (name, verts) in enumerate(roi_verts.items(), start=1):
+ parcellation[np.asarray(verts, dtype=int)] = float(i)
+ return parcellation
+
+
+# --- Tests for Python contour utility ---
+
+
+class TestGetContourVertices:
+ def test_returns_border_vertices(self):
+ """get_contour_vertices should return True at vertices bordering
+ different label values."""
+ parcellation = _make_parcellation()
+ border = cortex.utils.get_contour_vertices(parcellation, SUBJECT)
+ assert border.dtype == bool
+ assert border.shape == parcellation.shape
+ # There should be some border vertices (parcellation has multiple labels)
+ assert border.sum() > 0
+ # Border vertices should be fewer than total labeled vertices
+ assert border.sum() < (parcellation > 0).sum()
+
+ def test_uniform_data_has_no_borders(self):
+ """Uniform data (all same label) should have no border vertices."""
+ n_verts = cortex.db.get_surf(SUBJECT, "fiducial", merge=True)[0].shape[0]
+ uniform = np.ones(n_verts, dtype=float)
+ border = cortex.utils.get_contour_vertices(uniform, SUBJECT)
+ assert border.sum() == 0
+
+ def test_border_vertices_are_adjacent_to_different_labels(self):
+ """Every border vertex should have at least one neighbor with a
+ different label."""
+ parcellation = _make_parcellation()
+ border = cortex.utils.get_contour_vertices(parcellation, SUBJECT)
+
+ _, polys = cortex.db.get_surf(SUBJECT, "fiducial", merge=True)
+ neighbors = cortex.utils._get_neighbors_dict(polys)
+
+ # Check a sample of border vertices
+ border_verts = np.where(border)[0][:100]
+ for v in border_verts:
+ neighbor_labels = {
+ parcellation[n] for n in neighbors[v] if n < len(parcellation)
+ }
+ assert (
+ len(neighbor_labels) > 1 or parcellation[v] not in neighbor_labels
+ ), f"Border vertex {v} has no neighbor with different label"
+
+
+# --- Tests for _detect_label_borders (2D image helper) ---
+
+
+class TestDetectLabelBorders:
+ def test_uniform_image_no_borders(self):
+ """Uniform label image should have no borders."""
+ img = np.ones((50, 50))
+ border = _detect_label_borders(img)
+ assert border.sum() == 0
+
+ def test_two_regions_has_border(self):
+ """Image split into two regions should have a border between them."""
+ img = np.ones((50, 50))
+ img[:, 25:] = 2.0
+ border = _detect_label_borders(img)
+ # Border should be at column 24 and 25 (the boundary pixels)
+ assert border.sum() > 0
+ # Border pixels should be in the middle columns
+ border_cols = np.where(border.any(axis=0))[0]
+ assert 24 in border_cols or 25 in border_cols
+
+ def test_nan_pixels_are_not_borders(self):
+ """NaN pixels (outside brain mask) should not be marked as borders."""
+ img = np.full((50, 50), np.nan)
+ img[10:40, 10:40] = 1.0
+ img[10:40, 25:40] = 2.0
+ border = _detect_label_borders(img)
+ # Should have borders between label 1 and 2, but not at NaN edges
+ # (NaN-to-value transitions should not count as borders since
+ # we're interested in parcel-to-parcel boundaries, not brain mask edges)
+ nan_mask = np.isnan(img)
+ assert not border[nan_mask].any()
+
+ def test_3d_image_uses_first_channel(self):
+ """For RGBA images, borders should be detected on first channel."""
+ img = np.ones((50, 50, 4))
+ img[:, 25:, 0] = 2.0
+ border = _detect_label_borders(img)
+ assert border.sum() > 0
+
+
+# --- Tests for quickflat contour rendering ---
+
+
+class TestQuickflatContours:
+ @pytest.mark.skipif(no_inkscape, reason="Inkscape required")
+ def test_add_contours_returns_image(self):
+ """add_contours() should return a matplotlib AxesImage."""
+ from matplotlib import pyplot as plt
+ from cortex.quickflat.composite import add_contours
+
+ parcellation = _make_parcellation()
+ parc_vertex = cortex.Vertex(parcellation, SUBJECT)
+
+ fig, ax = plt.subplots()
+ img = add_contours(ax, parc_vertex, height=256)
+ assert img is not None
+ plt.close(fig)
+
+ @pytest.mark.skipif(no_inkscape, reason="Inkscape required")
+ def test_add_contours_border_pixels_are_opaque(self):
+ """Contour overlay should have opaque pixels only at label borders."""
+ from cortex.quickflat.composite import add_contours
+ from matplotlib import pyplot as plt
+
+ parcellation = _make_parcellation()
+ parc_vertex = cortex.Vertex(parcellation, SUBJECT)
+
+ fig, ax = plt.subplots()
+ img = add_contours(ax, parc_vertex, height=256)
+ rgba = img.get_array()
+ # Alpha channel should be > 0 only at border pixels
+ has_content = rgba[:, :, 3] > 0 if rgba.ndim == 3 else rgba > 0
+ assert has_content.any(), "No contour pixels found"
+ plt.close(fig)
+
+ @pytest.mark.skipif(no_inkscape, reason="Inkscape required")
+ def test_make_figure_with_contours(self):
+ """make_figure() with with_contours should produce a figure with
+ contour overlay."""
+ parcellation = _make_parcellation()
+ parc_vertex = cortex.Vertex(parcellation, SUBJECT)
+ activation = cortex.Vertex(
+ np.random.randn(parcellation.shape[0]), SUBJECT, cmap="hot", vmin=-2, vmax=2
+ )
+
+ fig = cortex.quickflat.make_figure(
+ activation,
+ with_contours=parc_vertex,
+ with_rois=False,
+ with_colorbar=False,
+ height=256,
+ )
+ assert fig is not None
+ # Should have at least 2 images: data + contours
+ ax = fig.get_axes()[0]
+ images = ax.get_images()
+ assert len(images) >= 2
+
+ @pytest.mark.skipif(no_inkscape, reason="Inkscape required")
+ def test_make_figure_volume_with_vertex_contours(self):
+ """make_figure() should work with Volume data + Vertex contour overlay."""
+ parcellation = _make_parcellation()
+ parc_vertex = cortex.Vertex(parcellation, SUBJECT)
+ volume = cortex.Volume.random(subject=SUBJECT, xfmname="fullhead")
+
+ fig = cortex.quickflat.make_figure(
+ volume,
+ with_contours=parc_vertex,
+ with_rois=False,
+ with_colorbar=False,
+ height=256,
+ )
+ assert fig is not None
+ ax = fig.get_axes()[0]
+ images = ax.get_images()
+ assert len(images) >= 2
+
+ @pytest.mark.skipif(no_inkscape, reason="Inkscape required")
+ def test_contour_linewidth(self):
+ """Thicker linewidth should produce more border pixels."""
+ from cortex.quickflat.composite import add_contours
+ from matplotlib import pyplot as plt
+
+ parcellation = _make_parcellation()
+ parc_vertex = cortex.Vertex(parcellation, SUBJECT)
+
+ fig1, ax1 = plt.subplots()
+ img1 = add_contours(ax1, parc_vertex, height=256, linewidth=1)
+ rgba1 = img1.get_array()
+ n_pixels_1 = (
+ (rgba1[:, :, 3] > 0).sum() if rgba1.ndim == 3 else (rgba1 > 0).sum()
+ )
+
+ fig2, ax2 = plt.subplots()
+ img2 = add_contours(ax2, parc_vertex, height=256, linewidth=3)
+ rgba2 = img2.get_array()
+ n_pixels_3 = (
+ (rgba2[:, :, 3] > 0).sum() if rgba2.ndim == 3 else (rgba2 > 0).sum()
+ )
+
+ assert (
+ n_pixels_3 > n_pixels_1
+ ), f"linewidth=3 ({n_pixels_3}) should have more pixels than linewidth=1 ({n_pixels_1})"
+ plt.close("all")
+
+ @pytest.mark.skipif(no_inkscape, reason="Inkscape required")
+ def test_contour_linecolor(self):
+ """Custom linecolor should appear in the contour image."""
+ from cortex.quickflat.composite import add_contours
+ from matplotlib import pyplot as plt
+
+ parcellation = _make_parcellation()
+ parc_vertex = cortex.Vertex(parcellation, SUBJECT)
+
+ red = (1.0, 0.0, 0.0, 1.0)
+ fig, ax = plt.subplots()
+ img = add_contours(ax, parc_vertex, height=256, linecolor=red)
+ rgba = img.get_array()
+ if rgba.ndim == 3:
+ border_mask = rgba[:, :, 3] > 0
+ # Red channel should be 1.0 at border pixels
+ np.testing.assert_allclose(rgba[border_mask, 0], 1.0)
+ # Green and blue should be 0
+ np.testing.assert_allclose(rgba[border_mask, 1], 0.0)
+ np.testing.assert_allclose(rgba[border_mask, 2], 0.0)
+ plt.close(fig)
+
+
+# --- Tests for WebGL contour support ---
+
+
+class TestWebGLContours:
+ def test_shader_includes_contour_uniforms(self):
+ """Both surface_vertex and surface_pixel shaders should include
+ contour-related uniforms."""
+ import os
+
+ shader_path = os.path.join(
+ os.path.dirname(cortex.__file__), "webgl", "resources", "js", "shaderlib.js"
+ )
+ with open(shader_path, "r") as f:
+ shader_code = f.read()
+
+ assert "contourMode" in shader_code
+ assert "contourThreshold" in shader_code
+ assert "contourColor" in shader_code
+ assert "fwidth" in shader_code
+ assert "vDataValue" in shader_code
+ assert "vContourDataValue" in shader_code
+ assert "contourColormap" in shader_code
+
+ def test_geometry_has_contour_attributes(self):
+ """mriview_surface.js should initialize contourData attributes."""
+ import os
+
+ surface_path = os.path.join(
+ os.path.dirname(cortex.__file__),
+ "webgl",
+ "resources",
+ "js",
+ "mriview_surface.js",
+ )
+ with open(surface_path, "r") as f:
+ surface_code = f.read()
+
+ assert "contourData0" in surface_code
+ assert "contourData1" in surface_code
+
+ def test_viewer_has_contour_overlay_support(self):
+ """mriview.js should have contour overlay selection support."""
+ import os
+
+ viewer_path = os.path.join(
+ os.path.dirname(cortex.__file__), "webgl", "resources", "js", "mriview.js"
+ )
+ with open(viewer_path, "r") as f:
+ viewer_code = f.read()
+
+ assert "setContourOverlay" in viewer_code or "contour_overlay" in viewer_code
diff --git a/cortex/utils.py b/cortex/utils.py
index 8cfaf201d..40d9e1bba 100644
--- a/cortex/utils.py
+++ b/cortex/utils.py
@@ -1,5 +1,5 @@
-"""Contain utility functions
-"""
+"""Contain utility functions"""
+
import binascii
import copy
from importlib import import_module
@@ -12,7 +12,19 @@
import urllib.request
import warnings
-from typing import Any, Callable, Generic, Optional, TypeVar, TYPE_CHECKING, Union, cast, overload, Literal
+from typing import (
+ Any,
+ Callable,
+ Generic,
+ Optional,
+ TypeVar,
+ TYPE_CHECKING,
+ Union,
+ cast,
+ overload,
+ Literal,
+)
+
if sys.version_info < (3, 10):
from typing_extensions import ParamSpec
else:
@@ -33,47 +45,67 @@
# register_cmap is deprecated in matplotlib > 3.7.0 and replaced by colormaps.register
try:
from matplotlib import colormaps as cm
+
def register_cmap(cmap):
return cm.register(cmap)
except ImportError:
from matplotlib.cm import register_cmap
-P = ParamSpec('P')
-T = TypeVar('T')
+P = ParamSpec("P")
+T = TypeVar("T")
+
class DocLoader(Generic[P, T]):
- def __init__(self, func, mod, package, actual_func: Optional[Callable[P, T]] = None):
- self._load: Callable[[], Callable[P, T]] = lambda: getattr(import_module(mod, package), func)
- self._actual_func = actual_func # stored only to resolve generic types during type checking
+ def __init__(
+ self, func, mod, package, actual_func: Optional[Callable[P, T]] = None
+ ):
+ self._load: Callable[[], Callable[P, T]] = lambda: getattr(
+ import_module(mod, package), func
+ )
+ self._actual_func = (
+ actual_func # stored only to resolve generic types during type checking
+ )
def __call__(self, *args: P.args, **kwargs: P.kwargs) -> T:
return self._load()(*args, **kwargs)
@overload
- def __getattribute__(self, name: Literal['_load']) -> Callable[P, T]: ...
+ def __getattribute__(self, name: Literal["_load"]) -> Callable[P, T]: ...
@overload
def __getattribute__(self, name: str) -> Any: ...
- def __getattribute__(self, name: Union[Literal['_load'], str]) -> Union[Any, Callable[P, T]]:
+ def __getattribute__(
+ self, name: Union[Literal["_load"], str]
+ ) -> Union[Any, Callable[P, T]]:
if name != "_load":
return getattr(self._load(), name)
else:
return cast(Callable[P, T], object.__getattribute__(self, name))
+
if TYPE_CHECKING:
from cortex.mapper import get_mapper as _get_mapper
else:
_get_mapper = None
get_mapper = DocLoader("get_mapper", ".mapper", "cortex", actual_func=_get_mapper)
+
def get_roipack(*args, **kwargs):
- warnings.warn('Please use db.get_overlay instead', DeprecationWarning)
+ warnings.warn("Please use db.get_overlay instead", DeprecationWarning)
return db.get_overlay(*args, **kwargs)
-def get_ctmpack(subject, types=("inflated",), method="raw", level=0, recache=False,
- decimate=False, external_svg=None,
- overlays_available=None):
+
+def get_ctmpack(
+ subject,
+ types=("inflated",),
+ method="raw",
+ level=0,
+ recache=False,
+ decimate=False,
+ external_svg=None,
+ overlays_available=None,
+):
"""Creates ctm file for the specified input arguments.
This is a cached file that specifies (1) the surfaces between which
@@ -108,12 +140,10 @@ def get_ctmpack(subject, types=("inflated",), method="raw", level=0, recache=Fal
-------
ctmfile :
"""
- lvlstr = ("%dd" if decimate else "%d")%level
+ lvlstr = ("%dd" if decimate else "%d") % level
# Generates different cache files for each combination of disp_layers
- ctmcache = "%s_[{types}]_{method}_{level}_v3.json"%subject
- ctmcache = ctmcache.format(types=','.join(types),
- method=method,
- level=lvlstr)
+ ctmcache = "%s_[{types}]_{method}_{level}_v3.json" % subject
+ ctmcache = ctmcache.format(types=",".join(types), method=method, level=lvlstr)
ctmfile = os.path.join(db.get_cache(subject), ctmcache)
if os.path.exists(ctmfile) and not recache:
@@ -121,20 +151,23 @@ def get_ctmpack(subject, types=("inflated",), method="raw", level=0, recache=Fal
print("Generating new ctm file...")
from . import brainctm
- ptmap = brainctm.make_pack(ctmfile,
- subject,
- types=types,
- method=method,
- level=level,
- decimate=decimate,
- external_svg=external_svg,
- overlays_available=overlays_available)
+
+ ptmap = brainctm.make_pack(
+ ctmfile,
+ subject,
+ types=types,
+ method=method,
+ level=level,
+ decimate=decimate,
+ external_svg=external_svg,
+ overlays_available=overlays_available,
+ )
return ctmfile
def get_ctmmap(subject, **kwargs):
"""Return a mapping from the vertices in the CTM surface to the vertices
- in the freesurfer surface.
+ in the freesurfer surface.
The mapping is a numpy array, such that `ctm2fs_left[i] = j` means that the
i-th vertex in the CTM surface corresponds to the j-th vertex in the freesurfer
surface.
@@ -162,8 +195,9 @@ def get_ctmmap(subject, **kwargs):
from scipy.spatial import cKDTree
from . import brainctm
+
jsfile = get_ctmpack(subject, **kwargs)
- ctmfile = os.path.splitext(jsfile)[0]+".ctm"
+ ctmfile = os.path.splitext(jsfile)[0] + ".ctm"
# Load freesurfer surfaces
try:
@@ -209,6 +243,7 @@ def get_ctm2webgl_map(subject, **kwargs):
maximum length of 65535.
"""
from . import brainctm
+
# Load CTM surfaces
jsonfile = get_ctmpack(subject, **kwargs)
ctmfile = os.path.splitext(jsonfile)[0] + ".ctm"
@@ -290,7 +325,7 @@ def get_fs2webgl_map(subject, **kwargs):
return fs2webgl_left, fs2webgl_right
-def get_cortical_mask(subject, xfmname, type='nearest'):
+def get_cortical_mask(subject, xfmname, type="nearest"):
"""Gets the cortical mask for a particular transform
Parameters
@@ -301,12 +336,12 @@ def get_cortical_mask(subject, xfmname, type='nearest'):
Transform name
type : str
Mask type, one of {"cortical", "thin", "thick", "nearest", "line_nearest"}.
- - 'cortical' includes voxels contained within the cortical ribbon,
- between the freesurfer-estimated white matter and pial surfaces.
- - 'thin' includes voxels that are < 2mm away from the fiducial surface.
+ - 'cortical' includes voxels contained within the cortical ribbon,
+ between the freesurfer-estimated white matter and pial surfaces.
+ - 'thin' includes voxels that are < 2mm away from the fiducial surface.
- 'thick' includes voxels that are < 8mm away from the fiducial surface.
- 'nearest' includes only the voxels overlapping the fiducial surface.
- - 'line_nearest' includes all voxels that have any part within the cortical
+ - 'line_nearest' includes all voxels that have any part within the cortical
ribbon.
Returns
@@ -316,13 +351,13 @@ def get_cortical_mask(subject, xfmname, type='nearest'):
Notes
-----
- "nearest" is a conservative "cortical" mask, while "line_nearest" is a liberal
+ "nearest" is a conservative "cortical" mask, while "line_nearest" is a liberal
"cortical" mask.
"""
- if type == 'cortical':
+ if type == "cortical":
ppts, polys = db.get_surf(subject, "pia", merge=True, nudge=False)
wpts, polys = db.get_surf(subject, "wm", merge=True, nudge=False)
- thickness = np.sqrt(((ppts - wpts)**2).sum(1))
+ thickness = np.sqrt(((ppts - wpts) ** 2).sum(1))
dist, idx = get_vox_dist(subject, xfmname)
cortex = np.zeros(dist.shape, dtype=bool)
@@ -331,9 +366,9 @@ def get_cortical_mask(subject, xfmname, type='nearest'):
mask = idx == vert
cortex[mask] = dist[mask] <= thickness[vert]
if i % 100 == 0:
- print("%0.3f%%"%(i/float(len(verts)) * 100))
+ print("%0.3f%%" % (i / float(len(verts)) * 100))
return cortex
- elif type in ('thick', 'thin'):
+ elif type in ("thick", "thin"):
dist, idx = get_vox_dist(subject, xfmname)
return dist < dict(thick=8, thin=2)[type]
else:
@@ -379,8 +414,8 @@ def get_vox_dist(subject, xfmname, surface="fiducial", max_dist=np.inf):
return dist.T, argdist.T
-def get_hemi_masks(subject, xfmname, type='nearest'):
- '''Returns a binary mask of the left and right hemisphere
+def get_hemi_masks(subject, xfmname, type="nearest"):
+ """Returns a binary mask of the left and right hemisphere
surface voxels for the given subject.
Parameters
@@ -394,11 +429,13 @@ def get_hemi_masks(subject, xfmname, type='nearest'):
Returns
-------
- '''
+ """
return get_mapper(subject, xfmname, type=type).hemimasks
-def add_roi(data, name="new_roi", open_inkscape=True, add_path=True,
- overlay_file=None, **kwargs):
+
+def add_roi(
+ data, name="new_roi", open_inkscape=True, add_path=True, overlay_file=None, **kwargs
+):
"""Add new flatmap image to the ROI file for a subject.
(The subject is specified in creation of the data object)
@@ -440,14 +477,16 @@ def add_roi(data, name="new_roi", open_inkscape=True, add_path=True,
svg = db.get_overlay(dv.subject, overlay_file=overlay_file)
fp = io.BytesIO()
- quickflat.make_png(fp, dv, height=1024, with_rois=False, with_labels=False, **kwargs)
+ quickflat.make_png(
+ fp, dv, height=1024, with_rois=False, with_labels=False, **kwargs
+ )
fp.seek(0)
- svg.rois.add_shape(name, binascii.b2a_base64(fp.read()).decode('utf-8'), add_path)
+ svg.rois.add_shape(name, binascii.b2a_base64(fp.read()).decode("utf-8"), add_path)
if open_inkscape:
- inkscape_cmd = config.get('dependency_paths', 'inkscape')
- if LooseVersion(INKSCAPE_VERSION) < LooseVersion('1.0'):
- cmd = [inkscape_cmd, '-f', svg.svgfile]
+ inkscape_cmd = config.get("dependency_paths", "inkscape")
+ if LooseVersion(INKSCAPE_VERSION) < LooseVersion("1.0"):
+ cmd = [inkscape_cmd, "-f", svg.svgfile]
else:
cmd = [inkscape_cmd, svg.svgfile]
return sp.call(cmd)
@@ -463,6 +502,41 @@ def _get_neighbors_dict(polys):
return neighbors_dict
+def get_contour_vertices(data, subject, surface="fiducial"):
+ """Find vertices at borders of parcellation labels.
+
+ A vertex is a border vertex if any of its mesh neighbors has a different
+ label value. This is useful for drawing contour lines around parcellation
+ regions on the cortical surface.
+
+ Parameters
+ ----------
+ data : array_like, shape (n_vertices,)
+ Label values per vertex (e.g., parcellation integers).
+ subject : str
+ Subject name in the pycortex database.
+ surface : str
+ Surface type for adjacency computation (default: 'fiducial').
+
+ Returns
+ -------
+ border_mask : ndarray, shape (n_vertices,), dtype bool
+ True at border vertices, False elsewhere.
+ """
+ _, polys = db.get_surf(subject, surface, merge=True)
+ neighbors = _get_neighbors_dict(polys)
+ data = np.asarray(data)
+ border = np.zeros(len(data), dtype=bool)
+ for v, neighs in neighbors.items():
+ if v >= len(data):
+ continue
+ for n in neighs:
+ if n < len(data) and data[v] != data[n]:
+ border[v] = True
+ break
+ return border
+
+
def get_roi_verts(subject, roi=None, mask=False, overlay_file=None):
"""Return vertices for the given ROIs, or all ROIs if none are given.
@@ -509,8 +583,8 @@ def get_roi_verts(subject, roi=None, mask=False, overlay_file=None):
for name in roi:
roi_idx = np.intersect1d(svg.rois.get_mask(name), goodpts)
- # Now we want to include also the vertices that were removed from the flat
- # surface that is, for every vertex in roi_idx we want to add the pts that are
+ # Now we want to include also the vertices that were removed from the flat
+ # surface that is, for every vertex in roi_idx we want to add the pts that are
# not in goodpts but that are in pts_full
# to do that, we need to find the neighboring indices from polys_full
extra_idx = set()
@@ -568,7 +642,7 @@ def get_roi_surf(subject, surf_type, roi, overlay_file=None):
return pts[vert_idx], np.array(reindexed_polys)
-def get_roi_mask(subject, xfmname, roi=None, projection='nearest'):
+def get_roi_mask(subject, xfmname, roi=None, projection="nearest"):
"""Return a mask for the given ROI(s)
Deprecated - use get_roi_masks()
@@ -588,15 +662,15 @@ def get_roi_mask(subject, xfmname, roi=None, projection='nearest'):
output : dict
Dict of ROIs and their masks
"""
- warnings.warn('Deprecated! Use get_roi_masks')
+ warnings.warn("Deprecated! Use get_roi_masks")
mapper = get_mapper(subject, xfmname, type=projection)
rois = get_roi_verts(subject, roi=roi, mask=True)
output = dict()
for name, verts in list(rois.items()):
# This is broken; unclear when/if backward mappers ever worked this way.
- #left, right = mapper.backwards(vert_mask)
- #output[name] = left + right
+ # left, right = mapper.backwards(vert_mask)
+ # output[name] = left + right
output[name] = mapper.backwards(verts.astype(float))
# Threshold?
return output
@@ -644,6 +718,7 @@ def get_aseg_mask(subject, aseg_name, xfmname=None, order=1, threshold=None, **k
"""
from .freesurfer import fs_aseg_dict
+
aseg = db.get_anat(subject, type="aseg").get_fdata().T
if not isinstance(aseg_name, (list, tuple)):
@@ -652,24 +727,36 @@ def get_aseg_mask(subject, aseg_name, xfmname=None, order=1, threshold=None, **k
mask = np.zeros(aseg.shape)
for name in aseg_name:
if name in fs_aseg_dict:
- tmp = aseg==fs_aseg_dict[name]
+ tmp = aseg == fs_aseg_dict[name]
else:
# Combine all masks containing `name` (e.g. all masks with 'cerebellum' in the name)
keys = [k for k in fs_aseg_dict.keys() if name.lower() in k.lower()]
if len(keys) == 0:
- raise ValueError('Unknown aseg_name!')
- tmp = np.any(np.array([aseg==fs_aseg_dict[k] for k in keys]), axis=0)
+ raise ValueError("Unknown aseg_name!")
+ tmp = np.any(np.array([aseg == fs_aseg_dict[k] for k in keys]), axis=0)
mask = np.logical_or(mask, tmp)
if xfmname is not None:
- mask = anat2epispace(mask.astype(float), subject, xfmname, order=order, **kwargs)
+ mask = anat2epispace(
+ mask.astype(float), subject, xfmname, order=order, **kwargs
+ )
if threshold is not None:
mask = mask > threshold
return mask
-def get_roi_masks(subject, xfmname, roi_list=None, gm_sampler='cortical', split_lr=False,
- allow_overlap=False, fail_for_missing_rois=True, exclude_empty_rois=False,
- threshold=None, return_dict=True, overlay_file=None):
+def get_roi_masks(
+ subject,
+ xfmname,
+ roi_list=None,
+ gm_sampler="cortical",
+ split_lr=False,
+ allow_overlap=False,
+ fail_for_missing_rois=True,
+ exclude_empty_rois=False,
+ threshold=None,
+ return_dict=True,
+ overlay_file=None,
+):
"""Return a dictionary of roi masks
This function returns a single 3D array with a separate numerical index for each ROI,
@@ -748,13 +835,17 @@ def get_roi_masks(subject, xfmname, roi_list=None, gm_sampler='cortical', split_
'thin' as your `gm_sampler`.
"""
# Convert mapper names to pycortex sampler types
- mapper_dict = {'cortical-conservative':'nearest',
- 'cortical-liberal':'line_nearest'}
+ mapper_dict = {
+ "cortical-conservative": "nearest",
+ "cortical-liberal": "line_nearest",
+ }
# Method
use_mapper = gm_sampler in mapper_dict
- use_cortex_mask = (gm_sampler in ('cortical', 'thick', 'thin')) or not isinstance(gm_sampler, str)
+ use_cortex_mask = (gm_sampler in ("cortical", "thick", "thin")) or not isinstance(
+ gm_sampler, str
+ )
if not (use_mapper or use_cortex_mask):
- raise ValueError('Unknown gray matter sampler (gm_sampler)!')
+ raise ValueError("Unknown gray matter sampler (gm_sampler)!")
# Initialize
roi_voxels = {}
pct_coverage = {}
@@ -773,18 +864,28 @@ def get_roi_masks(subject, xfmname, roi_list=None, gm_sampler='cortical', split_
roi_verts = get_roi_verts(subject, mask=use_mapper, overlay_file=overlay_file)
roi_list = list(roi_verts.keys())
else:
- tmp_list = [r for r in roi_list if not r=='Cortex']
+ tmp_list = [r for r in roi_list if not r == "Cortex"]
try:
- roi_verts = get_roi_verts(subject, roi=tmp_list, mask=use_mapper, overlay_file=overlay_file)
+ roi_verts = get_roi_verts(
+ subject, roi=tmp_list, mask=use_mapper, overlay_file=overlay_file
+ )
except KeyError as key:
if fail_for_missing_rois:
- raise KeyError("Requested ROI {} not found in overlays.svg!".format(key))
+ raise KeyError(
+ "Requested ROI {} not found in overlays.svg!".format(key)
+ )
else:
- roi_verts = get_roi_verts(subject, roi=None, mask=use_mapper, overlay_file=overlay_file)
- missing = [r for r in roi_list if not r in roi_verts.keys()+['Cortex']]
- roi_verts = dict((roi, verts) for roi, verts in roi_verts.items() if roi in roi_list)
- roi_list = list(set(roi_list)-set(missing))
- print('Requested ROI(s) {} not found in overlays.svg!'.format(missing))
+ roi_verts = get_roi_verts(
+ subject, roi=None, mask=use_mapper, overlay_file=overlay_file
+ )
+ missing = [
+ r for r in roi_list if not r in roi_verts.keys() + ["Cortex"]
+ ]
+ roi_verts = dict(
+ (roi, verts) for roi, verts in roi_verts.items() if roi in roi_list
+ )
+ roi_list = list(set(roi_list) - set(missing))
+ print("Requested ROI(s) {} not found in overlays.svg!".format(missing))
# Get (a) indices for nearest vertex to each voxel
# and (b) distance from each voxel to nearest vertex in fiducial surface
if (use_cortex_mask or split_lr) or (not return_dict):
@@ -799,7 +900,7 @@ def get_roi_masks(subject, xfmname, roi_list=None, gm_sampler='cortical', split_
# Loop over ROIs to map vertices to volume, using mapper or cortex mask + vertex indices
for roi in roi_list:
if roi not in roi_verts:
- if not roi=='Cortex':
+ if not roi == "Cortex":
print("ROI {} not found...".format(roi))
continue
if use_mapper:
@@ -808,10 +909,14 @@ def get_roi_masks(subject, xfmname, roi_list=None, gm_sampler='cortical', split_
if threshold is not None:
roi_voxels[roi] = roi_voxels[roi] > threshold
# Check for partial / empty rois:
- vert_in_scan = np.hstack([np.array((m>0).sum(1)).flatten() for m in mapper.masks])
+ vert_in_scan = np.hstack(
+ [np.array((m > 0).sum(1)).flatten() for m in mapper.masks]
+ )
vert_in_scan = vert_in_scan[roi_verts[roi]]
elif use_cortex_mask:
- vox_in_roi = np.in1d(vox_idx.flatten(), roi_verts[roi]).reshape(vox_idx.shape)
+ vox_in_roi = np.in1d(vox_idx.flatten(), roi_verts[roi]).reshape(
+ vox_idx.shape
+ )
roi_voxels[roi] = vox_in_roi & cortex_mask
# This is not accurate... because vox_idx only contains the indices of the *nearest*
# vertex to each voxel, it excludes many vertices. I can't think of a way to compute
@@ -820,58 +925,69 @@ def get_roi_masks(subject, xfmname, roi_list=None, gm_sampler='cortical', split_
# Compute ROI coverage
pct_coverage[roi] = vert_in_scan.mean() * 100
if use_mapper:
- print("Found %0.2f%% of %s"%(pct_coverage[roi], roi))
+ print("Found %0.2f%% of %s" % (pct_coverage[roi], roi))
# Create cortex mask
all_mask = np.array(list(roi_voxels.values())).sum(0)
- if 'Cortex' in roi_list:
+ if "Cortex" in roi_list:
if use_mapper:
# cortex_mask isn't defined / exactly definable if you're using a mapper
- print("Cortex roi not included b/c currently not compatible with your selection for gm_sampler")
- _ = roi_list.pop(roi_list.index('Cortex'))
+ print(
+ "Cortex roi not included b/c currently not compatible with your selection for gm_sampler"
+ )
+ _ = roi_list.pop(roi_list.index("Cortex"))
else:
- roi_voxels['Cortex'] = (all_mask==0) & cortex_mask
+ roi_voxels["Cortex"] = (all_mask == 0) & cortex_mask
# Optionally cull voxels assigned to > 1 ROI due to partly overlapping ROI splines
# in inkscape overlays.svg file:
if not allow_overlap:
- print('Cutting {} overlapping voxels (should be < ~50)'.format(np.sum(all_mask > 1)))
+ print(
+ "Cutting {} overlapping voxels (should be < ~50)".format(
+ np.sum(all_mask > 1)
+ )
+ )
for roi in roi_list:
roi_voxels[roi][all_mask > 1] = False
# Split left / right hemispheres if desired
if split_lr:
# Use the fiducial surface because we need to have all vertices
- left_verts, _ = db.get_surf(subject, "fiducial", merge=False, nudge=True)
+ left_verts, _ = db.get_surf(subject, "fiducial", merge=False, nudge=True)
left_mask = vox_idx < len(np.unique(left_verts[1]))
right_mask = np.logical_not(left_mask)
roi_voxels_lr = {}
for roi in roi_list:
# roi_voxels may contain float values if using a mapper, therefore we need
# to manually set the voxels in the other hemisphere to False. Then we let
- # numpy do the conversion False -> 0.
- roi_voxels_lr[roi + '_L'] = copy.copy(roi_voxels[roi])
- roi_voxels_lr[roi + '_L'][right_mask] = False
- roi_voxels_lr[roi + '_R'] = copy.copy(roi_voxels[roi])
- roi_voxels_lr[roi + '_R'][left_mask] = False
+ # numpy do the conversion False -> 0.
+ roi_voxels_lr[roi + "_L"] = copy.copy(roi_voxels[roi])
+ roi_voxels_lr[roi + "_L"][right_mask] = False
+ roi_voxels_lr[roi + "_R"] = copy.copy(roi_voxels[roi])
+ roi_voxels_lr[roi + "_R"][left_mask] = False
output = roi_voxels_lr
else:
output = roi_voxels
# Check percent coverage / optionally cull empty ROIs
- for roi in set(roi_list)-set(['Cortex']):
+ for roi in set(roi_list) - set(["Cortex"]):
if pct_coverage[roi] < 100:
# if not np.any(mask) : reject ROI
- if pct_coverage[roi]==0:
- warnings.warn('ROI %s is entirely missing from your scan protocol!'%(roi))
+ if pct_coverage[roi] == 0:
+ warnings.warn(
+ "ROI %s is entirely missing from your scan protocol!" % (roi)
+ )
if exclude_empty_rois:
if split_lr:
- _ = output.pop(roi+'_L')
- _ = output.pop(roi+'_R')
+ _ = output.pop(roi + "_L")
+ _ = output.pop(roi + "_R")
else:
_ = output.pop(roi)
else:
# I think this is the only one for which this works correctly...
- if gm_sampler=='cortical-conservative':
- warnings.warn('ROI %s is only %0.2f%% contained in your scan protocol!'%(roi, pct_coverage[roi]))
+ if gm_sampler == "cortical-conservative":
+ warnings.warn(
+ "ROI %s is only %0.2f%% contained in your scan protocol!"
+ % (roi, pct_coverage[roi])
+ )
# Support alternative outputs for backward compatibility
if return_dict:
@@ -885,6 +1001,7 @@ def get_roi_masks(subject, xfmname, roi_list=None, gm_sampler='cortical', split_
idx_vol[left_mask] *= -1
return idx_vol, idx_labels
+
def get_dropout(subject: str, xfmname: str, power: float = 20):
"""Create a dropout Volume showing where EPI signal
is very low.
@@ -909,13 +1026,15 @@ def get_dropout(subject: str, xfmname: str, power: float = 20):
if rawdata.ndim > 3:
rawdata = rawdata.mean(0)
- rawdata[rawdata==0] = np.mean(rawdata[rawdata!=0])
+ rawdata[rawdata == 0] = np.mean(rawdata[rawdata != 0])
normdata = (rawdata - rawdata.min()) / (rawdata.max() - rawdata.min())
normdata = (1 - normdata) ** power
from .dataset import Volume
+
return Volume(normdata, subject, xfmname)
+
def make_movie(stim, outfile, fps=15, size="640x480"):
"""Makes an .ogv movie
@@ -939,10 +1058,12 @@ def make_movie(stim, outfile, fps=15, size="640x480"):
"""
import shlex
import subprocess as sp
+
cmd = "ffmpeg -r {fps} -i {infile} -b 4800k -g 30 -s {size} -vcodec libtheora {outfile}.ogv"
fcmd = cmd.format(infile=stim, size=size, fps=fps, outfile=outfile)
sp.call(shlex.split(fcmd))
+
def vertex_to_voxel(subject): # Am I deprecated in favor of mappers??? Maybe?
"""
Parameters
@@ -975,28 +1096,30 @@ def vertex_to_voxel(subject): # Am I deprecated in favor of mappers??? Maybe?
def _set_edge_distance_graph_attribute(graph, pts, polys):
- '''
+ """
adds the attribute 'edge distance' to a graph
- '''
+ """
import networkx as nx
l2_distance = lambda v1, v2: np.linalg.norm(pts[v1] - pts[v2])
- heuristic = l2_distance # A* heuristic
+ heuristic = l2_distance # A* heuristic
- if not nx.get_edge_attributes(graph, 'distance'): # Add edge distances as an attribute to this graph if it isn't there
+ if not nx.get_edge_attributes(
+ graph, "distance"
+ ): # Add edge distances as an attribute to this graph if it isn't there
edge_distances = dict()
- for x,y,z in polys:
- edge_distances[(x,y)] = heuristic(x,y)
- edge_distances[(y,x)] = heuristic(y,x)
- edge_distances[(y,z)] = heuristic(y,z)
- edge_distances[(z,y)] = heuristic(z,y)
- edge_distances[(x,z)] = heuristic(x,z)
- edge_distances[(z,x)] = heuristic(z,x)
- nx.set_edge_attributes(graph, edge_distances, name='distance')
+ for x, y, z in polys:
+ edge_distances[(x, y)] = heuristic(x, y)
+ edge_distances[(y, x)] = heuristic(y, x)
+ edge_distances[(y, z)] = heuristic(y, z)
+ edge_distances[(z, y)] = heuristic(z, y)
+ edge_distances[(x, z)] = heuristic(x, z)
+ edge_distances[(z, x)] = heuristic(z, x)
+ nx.set_edge_attributes(graph, edge_distances, name="distance")
def get_shared_voxels(subject, xfmname, hemi="both", merge=True, use_astar=True):
- '''Return voxels that are shared by multiple vertices, and for each such voxel,
+ """Return voxels that are shared by multiple vertices, and for each such voxel,
also returns the mutually farthest pair of vertices mapping to the voxel
Parameters
----------
@@ -1017,18 +1140,21 @@ def get_shared_voxels(subject, xfmname, hemi="both", merge=True, use_astar=True)
vox_vert_array: np.array,
array of dimensions # voxels X 3, columns being: (vox_idx, farthest_pair[0],
farthest_pair[1])
- '''
+ """
import networkx as nx
from scipy.sparse import find as sparse_find
- Lmask, Rmask = get_mapper(subject, xfmname).masks # Get masks for left and right hemisphere
- if hemi == 'both':
- hemispheres = ['lh', 'rh']
+
+ Lmask, Rmask = get_mapper(
+ subject, xfmname
+ ).masks # Get masks for left and right hemisphere
+ if hemi == "both":
+ hemispheres = ["lh", "rh"]
else:
hemispheres = [hemi]
out = []
for hem in hemispheres:
- if hem == 'lh':
+ if hem == "lh":
mask = Lmask
else:
mask = Rmask
@@ -1036,8 +1162,10 @@ def get_shared_voxels(subject, xfmname, hemi="both", merge=True, use_astar=True)
all_voxels = mask.tolil().transpose().rows # Map from voxels to verts
vert_to_vox_map = dict(zip(*(sparse_find(mask)[:2]))) # From verts to vox
- pts_fid, polys_fid = db.get_surf(subject, 'fiducial', hem) # Get the fiducial surface
- surf = Surface(pts_fid, polys_fid) #Get the fiducial surface
+ pts_fid, polys_fid = db.get_surf(
+ subject, "fiducial", hem
+ ) # Get the fiducial surface
+ surf = Surface(pts_fid, polys_fid) # Get the fiducial surface
graph = surf.graph
_set_edge_distance_graph_attribute(graph, pts_fid, polys_fid)
@@ -1046,32 +1174,44 @@ def get_shared_voxels(subject, xfmname, hemi="both", merge=True, use_astar=True)
heuristic = l2_distance # A* heuristic
if use_astar:
- shortest_path = lambda a, b: nx.astar_path(graph, a, b, heuristic=heuristic, weight='distance') # Find approximate shortest paths using A* search
+ shortest_path = lambda a, b: nx.astar_path(
+ graph, a, b, heuristic=heuristic, weight="distance"
+ ) # Find approximate shortest paths using A* search
else:
- shortest_path = surf.geodesic_path # Find shortest paths using geodesic distances
+ shortest_path = (
+ surf.geodesic_path
+ ) # Find shortest paths using geodesic distances
vox_vert_list = []
for vox_idx, vox in enumerate(all_voxels):
if len(vox) > 1: # If the voxel maps to multiple vertices
vox = np.array(vox).astype(int)
- for v1 in range(vox.size-1):
+ for v1 in range(vox.size - 1):
vert1 = vox[v1]
if vert1 in vert_to_vox_map: # If the vertex is a valid vertex
- for v2 in range(v1+1, vox.size):
+ for v2 in range(v1 + 1, vox.size):
vert2 = vox[v2]
- if vert2 in vert_to_vox_map: # If the vertex is a valid vertex
+ if (
+ vert2 in vert_to_vox_map
+ ): # If the vertex is a valid vertex
path = shortest_path(vert1, vert2)
# Test whether any vertex in path goes out of the voxel
- stays_in_voxel = all([(v in vert_to_vox_map) and (vert_to_vox_map[v] == vox_idx) for v in path])
+ stays_in_voxel = all(
+ [
+ (v in vert_to_vox_map)
+ and (vert_to_vox_map[v] == vox_idx)
+ for v in path
+ ]
+ )
if not stays_in_voxel:
vox_vert_list.append([vox_idx, vert1, vert2])
- tmp = np.array(vox_vert_list)
+ tmp = np.array(vox_vert_list)
# Add offset for right hem voxels
- if hem=='rh':
+ if hem == "rh":
tmp[:, 1:3] += Lmask.shape[0]
out.append(tmp)
- if hemi in ('lh', 'rh'):
+ if hemi in ("lh", "rh"):
return out[0]
else:
if merge:
@@ -1096,13 +1236,18 @@ def load_sparse_array(fname, varname):
conventions, so cannot be used to load arbitrary sparse arrays.
"""
import scipy.sparse
+
with h5py.File(fname) as hf:
- data = (hf['%s_data'%varname], hf['%s_indices'%varname], hf['%s_indptr'%varname])
- sparsemat = scipy.sparse.csr_matrix(data, shape=hf['%s_shape'%varname])
+ data = (
+ hf["%s_data" % varname],
+ hf["%s_indices" % varname],
+ hf["%s_indptr" % varname],
+ )
+ sparsemat = scipy.sparse.csr_matrix(data, shape=hf["%s_shape" % varname])
return sparsemat
-def save_sparse_array(fname, data, varname, mode='a'):
+def save_sparse_array(fname, data, varname, mode="a"):
"""Save a numpy sparse array to an hdf file
Results in relatively smaller file size than numpy.savez
@@ -1119,19 +1264,20 @@ def save_sparse_array(fname, data, varname, mode='a'):
write / append mode set, one of ['w','a'] (passed to h5py.File())
"""
import scipy.sparse
+
if not isinstance(data, scipy.sparse.csr.csr_matrix):
data_ = scipy.sparse.csr_matrix(data)
else:
data_ = data
with h5py.File(fname, mode=mode) as hf:
# Save indices
- hf.create_dataset(varname + '_indices', data=data_.indices, compression='gzip')
+ hf.create_dataset(varname + "_indices", data=data_.indices, compression="gzip")
# Save data
- hf.create_dataset(varname + '_data', data=data_.data, compression='gzip')
+ hf.create_dataset(varname + "_data", data=data_.data, compression="gzip")
# Save indptr
- hf.create_dataset(varname + '_indptr', data=data_.indptr, compression='gzip')
+ hf.create_dataset(varname + "_indptr", data=data_.indptr, compression="gzip")
# Save shape
- hf.create_dataset(varname + '_shape', data=data_.shape, compression='gzip')
+ hf.create_dataset(varname + "_shape", data=data_.shape, compression="gzip")
def get_cmap(name):
@@ -1149,10 +1295,11 @@ def get_cmap(name):
"""
import matplotlib.pyplot as plt
from matplotlib import colors
+
# unknown colormap, test whether it's in pycortex colormaps
- cmapdir = config.get('webgl', 'colormaps')
+ cmapdir = config.get("webgl", "colormaps")
colormaps = os.listdir(cmapdir)
- colormaps = sorted([c for c in colormaps if '.png' in c])
+ colormaps = sorted([c for c in colormaps if ".png" in c])
colormaps = dict((c[:-4], os.path.join(cmapdir, c)) for c in colormaps)
if name in colormaps:
I = plt.imread(colormaps[name])
@@ -1165,15 +1312,16 @@ def get_cmap(name):
try:
cmap = plt.cm.get_cmap(name)
except:
- raise Exception('Unkown color map!')
+ raise Exception("Unknown color map!")
return cmap
+
def add_cmap(cmap, name, cmapdir=None):
"""Add a colormap to pycortex.
This stores a matplotlib colormap in the pycortex filestore, such that it can
- be used in the webgl viewer in pycortex. See
- https://matplotlib.org/stable/users/explain/colors/colormap-manipulation.html
+ be used in the webgl viewer in pycortex. See
+ https://matplotlib.org/stable/users/explain/colors/colormap-manipulation.html
for more information about how to generate colormaps in matplotlib.
Parameters
@@ -1182,8 +1330,8 @@ def add_cmap(cmap, name, cmapdir=None):
Color map to be saved
name : str
Name for colormap, e.g. 'jet', 'blue_to_yellow', etc. The name will be used
- to generate a filename for the colormap stored in the pycortex store,
- so avoid illegal characters for a filename. This name will also be used to
+ to generate a filename for the colormap stored in the pycortex store,
+ so avoid illegal characters for a filename. This name will also be used to
specify this colormap in future calls to `cortex.quickflat.make_figure()`
or `cortex.webgl.show()`.
"""
@@ -1199,8 +1347,9 @@ def add_cmap(cmap, name, cmapdir=None):
plt.imsave(os.path.join(cmapdir, name), cmap_im, format="png")
-def download_subject(subject_id='fsaverage', url=None, pycortex_store=None,
- download_again=False):
+def download_subject(
+ subject_id="fsaverage", url=None, pycortex_store=None, download_again=False
+):
"""Download subjects to pycortex store
Parameters
@@ -1224,15 +1373,16 @@ def download_subject(subject_id='fsaverage', url=None, pycortex_store=None,
warnings.warn(
"{} is already present in the database. "
"Set download_again to True if you wish to download "
- "the subject again.".format(subject_id))
+ "the subject again.".format(subject_id)
+ )
return
# Map codes to URLs; more coming eventually
id_to_url = dict(
- fsaverage='https://ndownloader.figshare.com/files/17827577?private_link=4871247dce31e188e758',
+ fsaverage="https://ndownloader.figshare.com/files/17827577?private_link=4871247dce31e188e758",
)
if url is None:
if subject_id not in id_to_url:
- raise ValueError('Unknown subject_id!')
+ raise ValueError("Unknown subject_id!")
url = id_to_url[subject_id]
# Setup pycortex store location
if pycortex_store is None:
@@ -1242,12 +1392,11 @@ def download_subject(subject_id='fsaverage', url=None, pycortex_store=None,
# Download to temp dir
print("Downloading from: {}".format(url))
with tempfile.TemporaryDirectory() as tmp_dir:
- print('Downloading subject {} to {}'.format(subject_id, tmp_dir))
+ print("Downloading subject {} to {}".format(subject_id, tmp_dir))
fnout, _ = urllib.request.urlretrieve(
- url,
- os.path.join(tmp_dir, f"{subject_id}.tar.gz")
+ url, os.path.join(tmp_dir, f"{subject_id}.tar.gz")
)
- print(f'Done downloading to {fnout}')
+ print(f"Done downloading to {fnout}")
# Un-tar to pycortex store
with tarfile.open(fnout, "r:gz") as tar:
print("Extracting subject {} to {}".format(subject_id, pycortex_store))
@@ -1259,51 +1408,56 @@ def download_subject(subject_id='fsaverage', url=None, pycortex_store=None,
def rotate_flatmap(surf_id, theta, plot=False):
"""Rotate flatmap to be less V-shaped
-
+
Parameters
----------
surf_id : str
pycortex surface identifier
theta : scalar
- angle in degrees to rotate flatmaps (rotation is clockwise
+ angle in degrees to rotate flatmaps (rotation is clockwise
for right hemisphere and counter-clockwise for left)
plot : bool
Whether to make a coarse plot to visualize the changes
"""
# Lazy load of matplotlib
import matplotlib.pyplot as plt
- paths = db.get_paths(surf_id)['surfs']['flat']
+
+ paths = db.get_paths(surf_id)["surfs"]["flat"]
theta = np.radians(theta)
if plot:
fig, axs = plt.subplots(2, 2)
- for j, hem in enumerate(('lh','rh')):
+ for j, hem in enumerate(("lh", "rh")):
this_file = paths[hem]
pts, polys = formats.read_gii(this_file)
# Rotate clockwise (- rotation) for RH, counter-clockwise (+ rotation) for LH
- if hem == 'rh':
- rtheta = - theta
+ if hem == "rh":
+ rtheta = -theta
else:
rtheta = copy.copy(theta)
- rotation_mat = np.array([[np.cos(rtheta), -np.sin(rtheta)], [np.sin(rtheta), np.cos(rtheta)]])
+ rotation_mat = np.array(
+ [[np.cos(rtheta), -np.sin(rtheta)], [np.sin(rtheta), np.cos(rtheta)]]
+ )
rotated = rotation_mat.dot(pts[:, :2].T).T
pts_new = pts.copy()
pts_new[:, :2] = rotated
new_file, bkup_num = copy.copy(this_file), 0
while os.path.exists(new_file):
- new_file = this_file.replace('.gii', '_rotbkup%02d.gii'%bkup_num)
+ new_file = this_file.replace(".gii", "_rotbkup%02d.gii" % bkup_num)
bkup_num += 1
- print('Backing up file at %s...' % new_file)
+ print("Backing up file at %s..." % new_file)
shutil.copy(this_file, new_file)
formats.write_gii(this_file, pts_new, polys)
- print('Overwriting %s...' % this_file)
+ print("Overwriting %s..." % this_file)
if plot:
- axs[0,j].plot(*pts[::100, :2].T, marker='r.')
- axs[0,j].axis('equal')
- axs[1,j].plot(*pts_new[::100, :2].T, marker='b.')
- axs[1,j].axis('equal')
+ axs[0, j].plot(*pts[::100, :2].T, marker="r.")
+ axs[0, j].axis("equal")
+ axs[1, j].plot(*pts_new[::100, :2].T, marker="b.")
+ axs[1, j].axis("equal")
# Remove and back up overlays file
- overlay_file = db.get_paths(surf_id)['overlays']
- shutil.copy(overlay_file, overlay_file.replace('.svg', '_rotbkup%02d.svg'%bkup_num))
+ overlay_file = db.get_paths(surf_id)["overlays"]
+ shutil.copy(
+ overlay_file, overlay_file.replace(".svg", "_rotbkup%02d.svg" % bkup_num)
+ )
os.unlink(overlay_file)
# Regenerate file
svg = db.get_overlay(surf_id)
diff --git a/cortex/webgl/resources/js/mriview.js b/cortex/webgl/resources/js/mriview.js
index ff536df6c..1b6ec52cc 100644
--- a/cortex/webgl/resources/js/mriview.js
+++ b/cortex/webgl/resources/js/mriview.js
@@ -222,6 +222,37 @@ var mriview = (function(module) {
}
this.setData(data[0].name);
+
+ // Populate the contours folder: overlay first, then mode and threshold (only once)
+ if (!this._contourUIAdded) {
+ var contourOptions = {"none": "none"};
+ for (var dname in this.dataviews) {
+ if (this.dataviews[dname].vertex) {
+ contourOptions[dname] = dname;
+ }
+ }
+ this._contourOverlayName = "none";
+ var viewer = this;
+ for (var i = 0; i < this.surfs.length; i++) {
+ (function(surf) {
+ surf.surf.loaded.done(function() {
+ var contoursFolder = surf.surf.ui.contours;
+ // Overlay dropdown first (if multiple vertex datasets)
+ if (Object.keys(contourOptions).length > 1) {
+ contoursFolder.add({
+ overlay: {action:[viewer, "setContourOverlay", contourOptions]},
+ });
+ }
+ // Then mode and threshold
+ contoursFolder.add({
+ mode: {action:[surf.surf, "setContourMode", {off:0, "contours only":1, "contours + fill":2, "colored contours":3, "colored + fill":4}]},
+ threshold: {action:[surf.surf.uniforms.contourThreshold, "value", 0.001, 0.5]},
+ });
+ });
+ })(this.surfs[i]);
+ }
+ this._contourUIAdded = true;
+ }
};
module.Viewer.prototype.setData = function(name) {
@@ -614,6 +645,78 @@ var mriview = (function(module) {
this.playpause();
this.setData([datasets[(i+dir).mod(datasets.length)]]);
};
+ module.Viewer.prototype.setContourOverlay = function(name) {
+ if (name === undefined)
+ return this._contourOverlayName || "none";
+
+ if (name === "none" || name === null || name === 0) {
+ this._contourOverlayName = "none";
+ this.contourOverlay = null;
+ for (var i = 0; i < this.surfs.length; i++) {
+ this.surfs[i].surf.uniforms.contourOverlay.value = 0;
+ }
+ this.schedule();
+ return;
+ }
+
+ var overlayView = this.dataviews[name];
+ if (!overlayView) {
+ console.warn("setContourOverlay: dataset '" + name + "' not found. Available: " +
+ Object.keys(this.dataviews).join(", "));
+ return;
+ }
+ if (!overlayView.vertex) {
+ console.warn("setContourOverlay: dataset '" + name + "' is not vertex data. " +
+ "Contour overlays require vertex (surface) data.");
+ return;
+ }
+
+ this._contourOverlayName = name;
+ this.contourOverlay = name;
+ var overlayData = overlayView.data[0];
+
+ var viewer = this;
+ var applyOverlay = function() {
+ // Use frame 0 for contour overlay data. Parcellation overlays
+ // are typically single-frame (static labels). Multi-frame
+ // contour overlays are not currently supported.
+ var fframe = 0;
+ var verts = overlayData.verts[fframe];
+ var verts1 = overlayData.verts[(fframe+1) % overlayData.verts.length];
+ for (var i = 0; i < viewer.surfs.length; i++) {
+ var surf = viewer.surfs[i].surf;
+ surf.hemis.left.attributes.contourData0.array = verts[0].array;
+ surf.hemis.left.attributes.contourData0.needsUpdate = true;
+ surf.hemis.right.attributes.contourData0.array = verts[1].array;
+ surf.hemis.right.attributes.contourData0.needsUpdate = true;
+ surf.hemis.left.attributes.contourData1.array = verts1[0].array;
+ surf.hemis.left.attributes.contourData1.needsUpdate = true;
+ surf.hemis.right.attributes.contourData1.array = verts1[1].array;
+ surf.hemis.right.attributes.contourData1.needsUpdate = true;
+ surf.uniforms.contourOverlay.value = 1;
+ // Set vmin/vmax and colormap for colored contour lookup
+ surf.uniforms.contourVmin.value = overlayView.vmin[0].value[0];
+ surf.uniforms.contourVmax.value = overlayView.vmax[0].value[0];
+ surf.uniforms.contourColormap.value = overlayView.cmap[0].value;
+ }
+ viewer.schedule();
+ };
+
+ // Data may already be available (verts populated via progress callback)
+ // even though loaded.state() is "pending" (resolve() is never called
+ // for VertexData). Check verts directly.
+ if (overlayData.verts.length > 0) {
+ applyOverlay();
+ } else {
+ // Data not yet available — wait for progress
+ overlayData.loaded.progress(function() {
+ if (overlayData.verts.length > 0) {
+ applyOverlay();
+ }
+ });
+ }
+ };
+
module.Viewer.prototype.rmData = function(name) {
delete this.datasets[name];
$(this.object).find("#datasets li").each(function() {
diff --git a/cortex/webgl/resources/js/mriview_surface.js b/cortex/webgl/resources/js/mriview_surface.js
index 5b1a788ea..c6d8b9a41 100644
--- a/cortex/webgl/resources/js/mriview_surface.js
+++ b/cortex/webgl/resources/js/mriview_surface.js
@@ -71,6 +71,15 @@ var mriview = (function(module) {
contrast: { type:'f', value:parseFloat(viewopts.contrast)},
extratex: { type:'t', value:null},
+ // Contour rendering
+ contourMode: { type:'f', value: 0 },
+ contourThreshold: { type:'f', value: 0.01 },
+ contourColor: { type:'v3', value: new THREE.Vector3(0, 0, 0) },
+ contourOverlay: { type:'f', value: 0 },
+ contourVmin: { type:'f', value: 0 },
+ contourVmax: { type:'f', value: 1 },
+ contourColormap: { type:'t', value: new THREE.DataTexture(new Uint8Array([0,0,0,255]), 1, 1, THREE.RGBAFormat) },
+
// screen: { type:'t', value:this.volumebuf},
// screen_size:{ type:'v2', value:new THREE.Vector2(100, 100)},
}
@@ -110,7 +119,8 @@ var mriview = (function(module) {
sampler: {action:[this, "setSampler", ["nearest", "trilinear"]]},
uniform_illumination: {action:[this, "setUniformIllumination"]},
});
-
+
+ this.ui.addFolder("contours", true);
this.ui.addFolder("curvature", true).add({
brightness: {action:[this.uniforms.brightness, "value", 0, 1]},
@@ -257,6 +267,11 @@ var mriview = (function(module) {
hemi.addAttribute("data2", new THREE.BufferAttribute(new Float32Array(), 1));
hemi.addAttribute("data3", new THREE.BufferAttribute(new Float32Array(), 1));
+ //Queue contour overlay data attributes (pre-sized to vertex count for proper WebGL buffer allocation)
+ var nVerts = hemi.attributes.position.array.length / hemi.attributes.position.itemSize;
+ hemi.addAttribute("contourData0", new THREE.BufferAttribute(new Float32Array(nVerts), 1));
+ hemi.addAttribute("contourData1", new THREE.BufferAttribute(new Float32Array(nVerts), 1));
+
hemi.dynamic = true;
var pivots = {back:new THREE.Group(), front:new THREE.Group()};
pivots.front.add(pivots.back);
@@ -854,6 +869,13 @@ var mriview = (function(module) {
this.object.add(this.mesh);
}
+ module.Surface.prototype.setContourMode = function(val) {
+ if (val === undefined)
+ return this.uniforms.contourMode.value;
+ this.uniforms.contourMode.value = parseFloat(val);
+ viewer.schedule();
+ };
+
module.Surface.prototype.setUniformIllumination = function(val) {
if (val === undefined)
return this.uniforms.emissive.value.x == 1; // Check current state
diff --git a/cortex/webgl/resources/js/shaderlib.js b/cortex/webgl/resources/js/shaderlib.js
index 9eaeb73a7..ea7f77c19 100644
--- a/cortex/webgl/resources/js/shaderlib.js
+++ b/cortex/webgl/resources/js/shaderlib.js
@@ -382,6 +382,16 @@ var Shaderlib = (function() {
"varying vec3 vWorldPosition;",
// "varying float vDrop;",
+ // Contour overlay attributes and varyings
+ "attribute float contourData0;",
+ "attribute float contourData1;",
+ "varying float vContourDataValue;",
+ "varying vec4 vContourColor;",
+ "uniform float contourVmin;",
+ "uniform float contourVmax;",
+ "uniform sampler2D contourColormap;",
+ "uniform float framemix;",
+
"varying vec3 vPos_x[2];",
"#ifdef TWOD",
"varying vec3 vPos_y[2];",
@@ -446,6 +456,12 @@ var Shaderlib = (function() {
"gl_Position = projectionMatrix * modelViewMatrix * vec4( pos, 1.0 );",
"vWorldPosition = pos;",
+
+ // Contour overlay data
+ "vContourDataValue = mix(contourData0, contourData1, framemix);",
+ "float contourRange = contourVmax - contourVmin;",
+ "float contourNorm = contourRange > 0.0 ? clamp((vContourDataValue - contourVmin) / contourRange, 0.0, 1.0) : 0.0;",
+ "vContourColor = texture2D(contourColormap, vec2(contourNorm, 0.0));",
"}"
].join("\n");
@@ -504,7 +520,15 @@ var Shaderlib = (function() {
"varying float vMedial;",
"varying float vThickmix;",
"varying vec3 vWorldPosition;", // the x,y,z coordinates of this pixel
-
+
+ // Contour rendering uniforms and varyings
+ "varying float vContourDataValue;",
+ "varying vec4 vContourColor;",
+ "uniform float contourMode;",
+ "uniform float contourThreshold;",
+ "uniform vec3 contourColor;",
+ "uniform float contourOverlay;",
+
utils.standard_frag_vars,
utils.rand,
utils.edge,
@@ -647,8 +671,41 @@ var Shaderlib = (function() {
"vec4 tColor = (1. - step(.001, vMedial)) * texture2D(extratex, vUv);",
"#endif",
+ // Contour edge detection (uses overlay vertex data even for pixel-shaded volumes)
+ "float contourEdge = contourOverlay > 0.5 ? fwidth(vContourDataValue) : 0.0;",
+ "bool isBorder = contourEdge > contourThreshold;",
+
"gl_FragColor = cColor;",
- "gl_FragColor = vColor + (1.-vColor.a)*gl_FragColor;",
+ // contourMode: 0=off, 1=contours only, 2=contours+fill,
+ // 3=colored contours only, 4=colored contours+fill
+ "if (contourMode < 0.5) {",
+ // Mode 0: normal rendering, no contours
+ "gl_FragColor = vColor + (1.-vColor.a)*gl_FragColor;",
+ "} else if (contourMode < 1.5) {",
+ // Mode 1: contours only (interior = curvature, no data fill)
+ "if (isBorder) {",
+ "vec4 borderColor = vec4(contourColor, 1.0);",
+ "gl_FragColor = borderColor + (1.-borderColor.a)*gl_FragColor;",
+ "}",
+ "} else if (contourMode < 2.5) {",
+ // Mode 2: data + contour borders on top
+ "gl_FragColor = vColor + (1.-vColor.a)*gl_FragColor;",
+ "if (isBorder) {",
+ "vec4 borderColor = vec4(contourColor, 1.0);",
+ "gl_FragColor = borderColor + (1.-borderColor.a)*gl_FragColor;",
+ "}",
+ "} else if (contourMode < 3.5) {",
+ // Mode 3: colored contours only (no data fill)
+ "if (isBorder) {",
+ "gl_FragColor = vContourColor + (1.-vContourColor.a)*gl_FragColor;",
+ "}",
+ "} else {",
+ // Mode 4: data + colored contour borders on top
+ "gl_FragColor = vColor + (1.-vColor.a)*gl_FragColor;",
+ "if (isBorder) {",
+ "gl_FragColor = vContourColor + (1.-vContourColor.a)*gl_FragColor;",
+ "}",
+ "}",
// "gl_FragColor = hColor + (1.-hColor.a)*gl_FragColor;",
"#ifdef ROI_RENDER",
"gl_FragColor = rColor + (1.-rColor.a)*gl_FragColor;",
@@ -674,6 +731,9 @@ var Shaderlib = (function() {
attributes.flatBumpNorms = { type: 'v3', value:null };
attributes.flatheight = { type: 'f', value:null };
}
+ attributes['contourData0'] = {type:'f', value:null};
+ attributes['contourData1'] = {type:'f', value:null};
+
for (var i = 0; i < morphs-1; i++) {
attributes['mixSurfs'+i] = { type:'v4', value:null};
attributes['mixNorms'+i] = { type:'v3', value:null};
@@ -730,6 +790,14 @@ var Shaderlib = (function() {
"varying vec2 vUv;",
"varying float vCurv;",
"varying float vMedial;",
+ "varying float vDataValue;",
+ "attribute float contourData0;",
+ "attribute float contourData1;",
+ "varying float vContourDataValue;",
+ "varying vec4 vContourColor;",
+ "uniform float contourVmin;",
+ "uniform float contourVmax;",
+ "uniform sampler2D contourColormap;",
// "varying float vDrop;",
utils.mixer(morphs),
@@ -743,6 +811,7 @@ var Shaderlib = (function() {
"#ifdef RGBCOLORS",
"vColor = mix(data0, data1, framemix);",
+ "vDataValue = 0.0;",
"#else",
"vec2 cuv;",
// "vValue.x = (mix(data0, data1, framemix) - vmin[0]) / (vmax[0] - vmin[0]);",
@@ -752,7 +821,13 @@ var Shaderlib = (function() {
"cuv.y = (mix(data2, data3, framemix) - vmin[1]) / (vmax[1] - vmin[1]);",
"#endif",
"vColor = texture2D(colormap, cuv);",
+ "vDataValue = mix(data0, data1, framemix);",
"#endif",
+ "vContourDataValue = mix(contourData0, contourData1, framemix);",
+ // Look up contour color in the overlay's own colormap
+ "float contourRange = contourVmax - contourVmin;",
+ "float contourNorm = contourRange > 0.0 ? clamp((vContourDataValue - contourVmin) / contourRange, 0.0, 1.0) : 0.0;",
+ "vContourColor = texture2D(contourColormap, vec2(contourNorm, 0.0));",
"#ifdef CORTSHEET",
"vec3 mpos = mix(position, wm.xyz, use_thickmix);",
@@ -806,9 +881,19 @@ var Shaderlib = (function() {
// "varying float vDrop;",
"varying float vCurv;",
"varying float vMedial;",
+ "varying float vDataValue;",
+ "varying float vContourDataValue;",
+ "varying vec4 vContourColor;",
"uniform float thickmix;",
// utils.thickmixer,
+ // Contour rendering uniforms
+ // 0=off, 1=contours only, 2=contours+fill, 3=colored contours only, 4=colored contours+fill
+ "uniform float contourMode;",
+ "uniform float contourThreshold;",
+ "uniform vec3 contourColor;",
+ "uniform float contourOverlay;", // 0=use self data, 1=use overlay data
+
utils.standard_frag_vars,
// "#ifdef RGBCOLORS",
@@ -836,13 +921,41 @@ var Shaderlib = (function() {
"vec4 tColor = (1. - step(.001, vMedial)) * texture2D(extratex, vUv);",
"#endif",
- // "#ifndef RGBCOLORS",
- // "vec4 vColor = texture2D(colormap, vValue);",
- // "#endif",
-
+ // Contour edge detection
+ "float contourEdge = contourOverlay > 0.5 ? fwidth(vContourDataValue) : fwidth(vDataValue);",
+ "bool isBorder = contourEdge > contourThreshold;",
"gl_FragColor = cColor;",
- "gl_FragColor = vColor + (1.-vColor.a)*gl_FragColor;",
+ // contourMode: 0=off, 1=contours only, 2=contours+fill,
+ // 3=colored contours only, 4=colored contours+fill
+ "if (contourMode < 0.5) {",
+ // Mode 0: normal rendering, no contours
+ "gl_FragColor = vColor + (1.-vColor.a)*gl_FragColor;",
+ "} else if (contourMode < 1.5) {",
+ // Mode 1: contours only (interior = curvature)
+ "if (isBorder) {",
+ "vec4 borderColor = vec4(contourColor, 1.0);",
+ "gl_FragColor = borderColor + (1.-borderColor.a)*gl_FragColor;",
+ "}",
+ "} else if (contourMode < 2.5) {",
+ // Mode 2: data + contour borders on top
+ "gl_FragColor = vColor + (1.-vColor.a)*gl_FragColor;",
+ "if (isBorder) {",
+ "vec4 borderColor = vec4(contourColor, 1.0);",
+ "gl_FragColor = borderColor + (1.-borderColor.a)*gl_FragColor;",
+ "}",
+ "} else if (contourMode < 3.5) {",
+ // Mode 3: colored contours only (border color from overlay colormap)
+ "if (isBorder) {",
+ "gl_FragColor = vContourColor + (1.-vContourColor.a)*gl_FragColor;",
+ "}",
+ "} else {",
+ // Mode 4: data + colored contour borders on top
+ "gl_FragColor = vColor + (1.-vColor.a)*gl_FragColor;",
+ "if (isBorder) {",
+ "gl_FragColor = vContourColor + (1.-vContourColor.a)*gl_FragColor;",
+ "}",
+ "}",
//"gl_FragColor = vec4(1., 0., 0., 1.);",
// "gl_FragColor = hColor + (1.-hColor.a)*gl_FragColor;",
"#ifdef ROI_RENDER",
@@ -866,6 +979,9 @@ var Shaderlib = (function() {
for (var i = 0; i < 4; i++)
attributes['data'+i] = {type:opts.rgb ? 'v4':'f', value:null};
+ attributes['contourData0'] = {type:'f', value:null};
+ attributes['contourData1'] = {type:'f', value:null};
+
for (var i = 0; i < morphs-1; i++) {
attributes['mixSurfs'+i] = { type:'v4', value:null };
attributes['mixNorms'+i] = { type:'v3', value:null };
diff --git a/cortex/webgl/view.py b/cortex/webgl/view.py
index d32ad7916..242ccd2cc 100644
--- a/cortex/webgl/view.py
+++ b/cortex/webgl/view.py
@@ -28,19 +28,25 @@
from .FallbackLoader import FallbackLoader
try:
- cmapdir = options.config.get('webgl', 'colormaps')
+ cmapdir = options.config.get("webgl", "colormaps")
if not os.path.exists(cmapdir):
- raise Exception("Colormap directory (%s) does not exist"%cmapdir)
+ raise Exception("Colormap directory (%s) does not exist" % cmapdir)
except NoOptionError:
cmapdir = os.path.join(options.config.get("basic", "filestore"), "colormaps")
if not os.path.exists(cmapdir):
- raise Exception("Colormap directory was not defined in the config file and the default (%s) does not exist"%cmapdir)
+ raise Exception(
+ "Colormap directory was not defined in the config file and the default (%s) does not exist"
+ % cmapdir
+ )
domain_name = options.config.get("webgl", "domain_name")
colormaps = glob.glob(os.path.join(cmapdir, "*.png"))
-colormaps = [(os.path.splitext(os.path.split(cm)[1])[0], serve.make_base64(cm))
- for cm in sorted(colormaps)]
+colormaps = [
+ (os.path.splitext(os.path.split(cm)[1])[0], serve.make_base64(cm))
+ for cm in sorted(colormaps)
+]
+
def make_static(
outpath,
@@ -130,7 +136,7 @@ def make_static(
Smoothness of curvature overlay. Default None, which uses the value
specified in the config file.
surface_specularity : float or None, optional
- Specularity of surfaces visualized with the WebGL viewer.
+ Specularity of surfaces visualized with the WebGL viewer.
Default None, which uses the value specified in the config file under
`webgl_viewopts.specularity`.
**kwargs
@@ -285,24 +291,24 @@ def make_static(
def show(
data: Union[dataset.Dataset, dataset.Dataview],
- autoclose: Optional[bool]=None,
- open_browser: Optional[bool]=None,
- port: Optional[int]=None,
- pickerfun: Optional[Callable[[tuple[int, int, int], int, str], None]]=None,
- recache: bool=False,
- template: str="mixer.html",
- overlays_available: Optional[tuple[str, ...]]=None,
- overlays_visible: Optional[tuple[str, ...]]=("rois", "sulci"),
- labels_visible: Optional[tuple[str, ...]]=("rois",),
- types: Optional[tuple[str, ...]]=("inflated",),
- overlay_file: Optional[str]=None,
- curvature_brightness: Optional[float]=None,
- curvature_contrast: Optional[float]=None,
- curvature_smoothness: Optional[float]=None,
- surface_specularity: Optional[float]=None,
- title: str="Brain",
- layout: Optional[str]=None,
- display_url: bool=True,
+ autoclose: Optional[bool] = None,
+ open_browser: Optional[bool] = None,
+ port: Optional[int] = None,
+ pickerfun: Optional[Callable[[tuple[int, int, int], int, str], None]] = None,
+ recache: bool = False,
+ template: str = "mixer.html",
+ overlays_available: Optional[tuple[str, ...]] = None,
+ overlays_visible: Optional[tuple[str, ...]] = ("rois", "sulci"),
+ labels_visible: Optional[tuple[str, ...]] = ("rois",),
+ types: Optional[tuple[str, ...]] = ("inflated",),
+ overlay_file: Optional[str] = None,
+ curvature_brightness: Optional[float] = None,
+ curvature_contrast: Optional[float] = None,
+ curvature_smoothness: Optional[float] = None,
+ surface_specularity: Optional[float] = None,
+ title: str = "Brain",
+ layout: Optional[str] = None,
+ display_url: bool = True,
**kwargs,
):
"""
@@ -338,7 +344,7 @@ def show(
overlays_available : tuple, optional
Overlays available in the viewer. If None, then all overlay layers of the
svg file will be potentially available in the viewer (whether initially
- visible or not).
+ visible or not).
overlays_visible : tuple, optional
The listed overlay layers will be set visible by default. Layers not listed
here will be hidden by default (but can be enabled in the viewer GUI).
@@ -366,7 +372,7 @@ def show(
Smoothness of curvature overlay. Default None, which uses the value
specified in the config file.
surface_specularity : float or None, optional
- Specularity of surfaces visualized with the WebGL viewer.
+ Specularity of surfaces visualized with the WebGL viewer.
Default None, which uses the value specified in the config file under
`webgl_viewopts.specularity`.
title : str, optional
@@ -387,52 +393,61 @@ def show(
# populate default webshow args
if autoclose is None:
- autoclose = options.config.get('webshow', 'autoclose', fallback='true') == 'true'
+ autoclose = (
+ options.config.get("webshow", "autoclose", fallback="true") == "true"
+ )
if open_browser is None:
- open_browser = options.config.get('webshow', 'open_browser', fallback='true') == 'true'
+ open_browser = (
+ options.config.get("webshow", "open_browser", fallback="true") == "true"
+ )
data = dataset.normalize(data)
if not isinstance(data, dataset.Dataset):
data = dataset.Dataset(data=data)
- html = FallbackLoader([os.path.split(os.path.abspath(template))[0], serve.cwd]).load(template)
+ html = FallbackLoader(
+ [os.path.split(os.path.abspath(template))[0], serve.cwd]
+ ).load(template)
db.auxfile = data
- #Extract the list of stimuli, for special-casing
+ # Extract the list of stimuli, for special-casing
stims: dict[str, str] = dict()
for name, view in data:
- if 'stim' in view.attrs and os.path.exists(view.attrs['stim']):
- sname = os.path.split(view.attrs['stim'])[1]
- stims[sname] = view.attrs['stim']
+ if "stim" in view.attrs and os.path.exists(view.attrs["stim"]):
+ sname = os.path.split(view.attrs["stim"])[1]
+ stims[sname] = view.attrs["stim"]
package = Package(data)
metadata = json.dumps(package.metadata())
images = package.images
subjects = list(package.subjects)
- ctmargs = dict(method='mg2', level=9, recache=recache,
- external_svg=overlay_file, overlays_available=overlays_available)
- ctms = dict((subj, utils.get_ctmpack(subj, types, **ctmargs))
- for subj in subjects)
+ ctmargs = dict(
+ method="mg2",
+ level=9,
+ recache=recache,
+ external_svg=overlay_file,
+ overlays_available=overlays_available,
+ )
+ ctms = dict((subj, utils.get_ctmpack(subj, types, **ctmargs)) for subj in subjects)
package.reorder(ctms)
- subjectjs = json.dumps(dict((subj, "ctm/%s/"%subj) for subj in subjects))
+ subjectjs = json.dumps(dict((subj, "ctm/%s/" % subj) for subj in subjects))
db.auxfile = None
-
- linear = lambda x, y, m: (1.-m)*x + m*y
+ linear = lambda x, y, m: (1.0 - m) * x + m * y
mixes = dict(
linear=linear,
- smoothstep=(lambda x, y, m: linear(x, y, 3*m**2 - 2*m**3)),
- smootherstep=(lambda x, y, m: linear(x, y, 6*m**5 - 15*m**4 + 10*m**3))
+ smoothstep=(lambda x, y, m: linear(x, y, 3 * m**2 - 2 * m**3)),
+ smootherstep=(lambda x, y, m: linear(x, y, 6 * m**5 - 15 * m**4 + 10 * m**3)),
)
post_name: Queue[str] = Queue()
# Put together all view options
- my_viewopts: dict[str, Any] = dict(options.config.items('webgl_viewopts'))
- my_viewopts['overlays_visible'] = overlays_visible
- my_viewopts['labels_visible'] = labels_visible
+ my_viewopts: dict[str, Any] = dict(options.config.items("webgl_viewopts"))
+ my_viewopts["overlays_visible"] = overlays_visible
+ my_viewopts["labels_visible"] = labels_visible
my_viewopts["brightness"] = (
options.config.get("curvature", "brightness")
if curvature_brightness is None
@@ -455,7 +470,7 @@ def show(
)
for sec in options.config.sections():
- if 'paths' in sec or 'labels' in sec:
+ if "paths" in sec or "labels" in sec:
my_viewopts[sec] = dict(options.config.items(sec))
if pickerfun is None:
@@ -463,8 +478,8 @@ def show(
class CTMHandler(web.RequestHandler):
def get(self, path: str):
- subj, path = path.split('/')
- if path == '':
+ subj, path = path.split("/")
+ if path == "":
self.set_header("Content-Type", "application/json")
self.write(open(ctms[subj]).read())
else:
@@ -473,14 +488,14 @@ def get(self, path: str):
if mtype is None:
mtype = "application/octet-stream"
self.set_header("Content-Type", mtype)
- self.write(open(os.path.join(fpath, path), 'rb').read())
+ self.write(open(os.path.join(fpath, path), "rb").read())
class DataHandler(web.RequestHandler):
def get(self, path: str):
path = path.strip("/")
frame: Union[int, str]
try:
- dataname, frame = path.split('/')
+ dataname, frame = path.split("/")
except ValueError:
dataname = path
frame = 0
@@ -492,15 +507,21 @@ def get(self, path: str):
else:
self.set_header("Content-Type", "image/png")
- if 'Range' in self.request.headers:
+ if "Range" in self.request.headers:
self.set_status(206)
- rangestr = self.request.headers['Range'].split('=')[1]
- start, end = [ int(i) if len(i) > 0 else None for i in rangestr.split('-') ]
-
- clenheader = 'bytes %s-%s/%s' % (start, end or len(dataimg), len(dataimg) )
- self.set_header('Content-Range', clenheader)
- self.set_header('Content-Length', end-start+1)
- self.write(dataimg[start:end+1])
+ rangestr = self.request.headers["Range"].split("=")[1]
+ start, end = [
+ int(i) if len(i) > 0 else None for i in rangestr.split("-")
+ ]
+
+ clenheader = "bytes %s-%s/%s" % (
+ start,
+ end or len(dataimg),
+ len(dataimg),
+ )
+ self.set_header("Content-Range", clenheader)
+ self.set_header("Content-Length", end - start + 1)
+ self.write(dataimg[start : end + 1])
else:
self.write(dataimg)
else:
@@ -521,24 +542,26 @@ def get(self, path: str):
class StaticHandler(web.StaticFileHandler):
def initialize(self):
- self.root = ''
+ self.root = ""
class MixerHandler(web.RequestHandler):
def get(self):
self.set_header("Content-Type", "text/html")
- generated = html.generate(data=metadata,
- colormaps=colormaps,
- default_cmap="RdBu_r",
- python_interface=True,
- leapmotion=True,
- layout=layout,
- subjects=subjectjs,
- viewopts=json.dumps(my_viewopts),
- title=title,
- **kwargs)
- #overlays_visible=json.dumps(overlays_visible),
- #labels_visible=json.dumps(labels_visible),
- #**viewopts)
+ generated = html.generate(
+ data=metadata,
+ colormaps=colormaps,
+ default_cmap="RdBu_r",
+ python_interface=True,
+ leapmotion=True,
+ layout=layout,
+ subjects=subjectjs,
+ viewopts=json.dumps(my_viewopts),
+ title=title,
+ **kwargs,
+ )
+ # overlays_visible=json.dumps(overlays_visible),
+ # labels_visible=json.dumps(labels_visible),
+ # **viewopts)
self.write(generated)
def post(self):
@@ -554,13 +577,13 @@ def post(self):
data = png
svgfile.write(data)
- P = ParamSpec('P')
+ P = ParamSpec("P")
class JSMixer(serve.JSProxy[P]):
@property
def view_props(self) -> list[str]:
- """An enumerated list of settable properties for views.
- There may be a way to get this from the javascript object,
+ """An enumerated list of settable properties for views.
+ There may be a way to get this from the javascript object,
but I (ML) don't know how.
There may be additional properties we want to set in views
@@ -572,15 +595,24 @@ def view_props(self) -> list[str]:
'volume_vis', 'frame', 'slices']
"""
camera = getattr(self.ui, "camera")
- _camera_props = ['camera.%s' % k for k in camera._controls.attrs.keys()]
+ _camera_props = ["camera.%s" % k for k in camera._controls.attrs.keys()]
surface = getattr(self.ui, "surface")
_subject = list(surface._folders.attrs.keys())[0]
_surface = getattr(surface, _subject)
- _surface_props = ['surface.{subject}.%s'%k for k in _surface._controls.attrs.keys()]
- _curvature_props = ['surface.{subject}.curvature.brightness',
- 'surface.{subject}.curvature.contrast',
- 'surface.{subject}.curvature.smoothness']
- return _camera_props + _surface_props + _curvature_props
+ _surface_props = [
+ "surface.{subject}.%s" % k for k in _surface._controls.attrs.keys()
+ ]
+ _curvature_props = [
+ "surface.{subject}.curvature.brightness",
+ "surface.{subject}.curvature.contrast",
+ "surface.{subject}.curvature.smoothness",
+ ]
+ _contour_props = [
+ "surface.{subject}.contours.mode",
+ "surface.{subject}.contours.threshold",
+ "surface.{subject}.contours.overlay",
+ ]
+ return _camera_props + _surface_props + _curvature_props + _contour_props
def _set_view(self, **kwargs):
"""Low-level command: sets view parameters in the current viewer
@@ -588,24 +620,36 @@ def _set_view(self, **kwargs):
Sets each the state of each keyword argument provided. View parameters
that can be set include all parameters in the data.gui in the html view.
+ Contour-related parameters:
+ surface.{subject}.contours.mode : int
+ 0=off, 1=contours only, 2=contours+fill
+ surface.{subject}.contours.threshold : float
+ Edge detection threshold (0.001-0.5)
+ surface.{subject}.contours.overlay : str
+ Dataset name to use as contour overlay, or "none"
+
"""
# Set unfolding level first, as it interacts with other arguments
assert isinstance(self.ui, serve.JSProxy)
surface: serve.JSProxy[P] = getattr(self.ui, "surface")
subject_list = cast(serve.JSProxy[P], surface._folders).attrs.keys()
- # Better to only self.view_props once; it interacts with javascript,
+ # Better to only self.view_props once; it interacts with javascript,
# don't want to do that too often, it leads to glitches.
vw_props = copy.copy(self.view_props)
for subject in subject_list:
- if 'surface.{subject}.unfold' in kwargs:
- unfold = kwargs.pop('surface.{subject}.unfold')
- self.ui.set('surface.{subject}.unfold'.format(subject=subject), unfold)
+ if "surface.{subject}.unfold" in kwargs:
+ unfold = kwargs.pop("surface.{subject}.unfold")
+ self.ui.set(
+ "surface.{subject}.unfold".format(subject=subject), unfold
+ )
for k, v in kwargs.items():
if not k in vw_props:
- print('Unknown parameter %s!'%k)
+ print("Unknown parameter %s!" % k)
continue
else:
- self.ui.set(k.format(subject=subject) if '{subject}' in k else k, v)
+ self.ui.set(
+ k.format(subject=subject) if "{subject}" in k else k, v
+ )
# Wait for webgl. Wait for it. .... WAAAAAIIIT.
time.sleep(0.03)
@@ -621,7 +665,7 @@ def _capture_view(self, frame_time=None):
----------
frame_time : scalar
time (in seconds) to specify for this frame.
-
+
Notes
-----
If multiple subjects are present, only retrieves view for first subject.
@@ -630,16 +674,18 @@ def _capture_view(self, frame_time=None):
subject = list(self.ui.surface._folders.attrs.keys())[0]
for p in self.view_props:
try:
- view[p] = self.ui.get(p.format(subject=subject) if '{subject}' in p else p)[0]
+ view[p] = self.ui.get(
+ p.format(subject=subject) if "{subject}" in p else p
+ )[0]
# Wait for webgl.
time.sleep(0.03)
except Exception as err:
# TO DO: Fix this hack with an error class in serve.py & catch it here
- print(err) #msg = "Cannot read property 'undefined'"
- #if err.message[:len(msg)] != msg:
+ print(err) # msg = "Cannot read property 'undefined'"
+ # if err.message[:len(msg)] != msg:
# raise err
if frame_time is not None:
- view['time'] = frame_time
+ view["time"] = frame_time
return view
def save_view(self, subject, name, is_overwrite=False):
@@ -683,12 +729,14 @@ def get_view(self, subject, name):
def addData(self, **kwargs):
Proxy = serve.JSProxy(self.send, "window.viewers.addData")
- new_meta, new_ims = _convert_dataset(Dataset(**kwargs), path='/data/', fmt='%s_%d.png')
+ new_meta, new_ims = _convert_dataset(
+ Dataset(**kwargs), path="/data/", fmt="%s_%d.png"
+ )
metadata.update(new_meta)
images.update(new_ims)
return Proxy(metadata)
- def getImage(self, filename: str, size: tuple[int, int]=(1920, 1080)):
+ def getImage(self, filename: str, size: tuple[int, int] = (1920, 1080)):
"""Saves currently displayed view to a .png image file
Parameters
@@ -702,8 +750,15 @@ def getImage(self, filename: str, size: tuple[int, int]=(1920, 1080)):
Proxy = serve.JSProxy(self.send, "window.viewer.getImage")
return Proxy(size[0], size[1], "mixer.html")
- def makeMovie(self, animation, filename="brainmovie%07d.png", offset=0,
- fps=30, size=(1920, 1080), interpolation="linear"):
+ def makeMovie(
+ self,
+ animation,
+ filename="brainmovie%07d.png",
+ offset=0,
+ fps=30,
+ size=(1920, 1080),
+ interpolation="linear",
+ ):
"""Renders movie frames for animation of mesh movement
Makes an animation (for example, a transition between inflated and
@@ -754,39 +809,47 @@ def makeMovie(self, animation, filename="brainmovie%07d.png", offset=0,
# anim is a list of transitions between keyframes
anim = []
setfunc = self.ui.set
- for f in sorted(animation, key=lambda x:x['idx']):
- if f['idx'] == 0:
- setfunc(f['state'], f['value'])
- state[f['state']] = dict(idx=f['idx'], val=f['value'])
+ for f in sorted(animation, key=lambda x: x["idx"]):
+ if f["idx"] == 0:
+ setfunc(f["state"], f["value"])
+ state[f["state"]] = dict(idx=f["idx"], val=f["value"])
else:
- if f['state'] not in state:
- state[f['state']] = dict(idx=0, val=self.getState(f['state'])[0])
- start = dict(idx=state[f['state']]['idx'],
- state=f['state'],
- value=state[f['state']]['val'])
- end = dict(idx=f['idx'], state=f['state'], value=f['value'])
- state[f['state']]['idx'] = f['idx']
- state[f['state']]['val'] = f['value']
- if start['value'] != end['value']:
+ if f["state"] not in state:
+ state[f["state"]] = dict(
+ idx=0, val=self.getState(f["state"])[0]
+ )
+ start = dict(
+ idx=state[f["state"]]["idx"],
+ state=f["state"],
+ value=state[f["state"]]["val"],
+ )
+ end = dict(idx=f["idx"], state=f["state"], value=f["value"])
+ state[f["state"]]["idx"] = f["idx"]
+ state[f["state"]]["val"] = f["value"]
+ if start["value"] != end["value"]:
anim.append((start, end))
- for i, sec in enumerate(np.arange(0, anim[-1][1]['idx']+1./fps, 1./fps)):
+ for i, sec in enumerate(
+ np.arange(0, anim[-1][1]["idx"] + 1.0 / fps, 1.0 / fps)
+ ):
for start, end in anim:
- if start['idx'] < sec <= end['idx']:
- idx = (sec - start['idx']) / float(end['idx'] - start['idx'])
- if start['state'] == 'frame':
- func = mixes['linear']
+ if start["idx"] < sec <= end["idx"]:
+ idx = (sec - start["idx"]) / float(end["idx"] - start["idx"])
+ if start["state"] == "frame":
+ func = mixes["linear"]
else:
func = mixes[interpolation]
- val = func(np.array(start['value']), np.array(end['value']), idx)
+ val = func(
+ np.array(start["value"]), np.array(end["value"]), idx
+ )
if isinstance(val, np.ndarray):
- setfunc(start['state'], val.ravel().tolist())
+ setfunc(start["state"], val.ravel().tolist())
else:
- setfunc(start['state'], val)
- self.getImage(filename%(i+offset), size=size)
+ setfunc(start["state"], val)
+ self.getImage(filename % (i + offset), size=size)
- def _get_anim_seq(self, keyframes, fps=30, interpolation='linear'):
+ def _get_anim_seq(self, keyframes, fps=30, interpolation="linear"):
"""Convert a list of keyframes to a list of EVERY frame in an animation.
Utility function called by make_movie; separated out so that individual
@@ -798,23 +861,23 @@ def _get_anim_seq(self, keyframes, fps=30, interpolation='linear'):
fr = 0
a = np.array
func = mixes[interpolation]
- #skip_props = ['surface.{subject}.right', 'surface.{subject}.left', ] #'projection',
+ # skip_props = ['surface.{subject}.right', 'surface.{subject}.left', ] #'projection',
# Get keyframes
- keyframes = sorted(keyframes, key=lambda x:x['time'])
+ keyframes = sorted(keyframes, key=lambda x: x["time"])
# Normalize all time to frame rate
- fs = 1./fps
+ fs = 1.0 / fps
for k in range(len(keyframes)):
- t = keyframes[k]['time']
- t = np.round(t/fs)*fs
- keyframes[k]['time'] = t
+ t = keyframes[k]["time"]
+ t = np.round(t / fs) * fs
+ keyframes[k]["time"] = t
allframes = []
for start, end in zip(keyframes[:-1], keyframes[1:]):
- t0 = start['time']
- t1 = end['time']
- tdif = float(t1-t0)
+ t0 = start["time"]
+ t1 = end["time"]
+ tdif = float(t1 - t0)
# Check whether to continue frame sequence to endpoint
- use_endpoint = keyframes[-1]==end
- nvalues = np.round(tdif/fs).astype(int)
+ use_endpoint = keyframes[-1] == end
+ nvalues = np.round(tdif / fs).astype(int)
if use_endpoint:
nvalues += 1
fr_time = np.linspace(0, 1, nvalues, endpoint=use_endpoint)
@@ -822,9 +885,13 @@ def _get_anim_seq(self, keyframes, fps=30, interpolation='linear'):
for t in fr_time:
frame = {}
for prop in start.keys():
- if prop=='time':
+ if prop == "time":
continue
- if (start[prop] is None) or (start[prop] == end[prop]) or isinstance(start[prop], (bool, str)):
+ if (
+ (start[prop] is None)
+ or (start[prop] == end[prop])
+ or isinstance(start[prop], (bool, str))
+ ):
frame[prop] = start[prop]
continue
val = func(a(start[prop]), a(end[prop]), t)
@@ -835,9 +902,18 @@ def _get_anim_seq(self, keyframes, fps=30, interpolation='linear'):
allframes.append(frame)
return allframes
- def make_movie_views(self, animation, filename="brainmovie%07d.png",
- offset=0, fps=30, size=(1920, 1080), alpha=1, frame_sleep=0.05,
- frame_start=0, interpolation="linear"):
+ def make_movie_views(
+ self,
+ animation,
+ filename="brainmovie%07d.png",
+ offset=0,
+ fps=30,
+ size=(1920, 1080),
+ alpha=1,
+ frame_sleep=0.05,
+ frame_start=0,
+ interpolation="linear",
+ ):
"""Renders movie frames for animation of mesh movement
Makes an animation (for example, a transition between inflated and
@@ -894,7 +970,7 @@ def make_movie_views(self, animation, filename="brainmovie%07d.png",
for fr, frame in enumerate(allframes[frame_start:], frame_start):
self._set_view(**frame)
time.sleep(frame_sleep)
- self.getImage(filename%(fr+offset+1), size=size)
+ self.getImage(filename % (fr + offset + 1), size=size)
time.sleep(frame_sleep)
class PickerHandler(web.RequestHandler):
@@ -907,7 +983,9 @@ def get(self):
parts = voxel_arg.split(",")
if len(parts) != 3:
self.set_status(400)
- self.finish("Invalid 'voxel' query parameter: expected 3 comma-separated integers")
+ self.finish(
+ "Invalid 'voxel' query parameter: expected 3 comma-separated integers"
+ )
return
try:
voxel: tuple[int, int, int] = tuple(int(i) for i in parts)
@@ -921,6 +999,7 @@ def get(self):
class WebApp(serve.WebApp):
disconnect_on_close = autoclose
+
def get_client(self):
self.connect.wait()
self.connect.clear()
@@ -932,18 +1011,22 @@ def get_local_client(self):
if port is None:
port = random.randint(1024, 65536)
- server = WebApp([(r'/ctm/(.*)', CTMHandler),
- (r'/data/(.*)', DataHandler),
- (r'/stim/(.*)', StimHandler),
- (r'/mixer.html', MixerHandler),
- (r'/picker', PickerHandler),
- (r'/', MixerHandler),
- (r'/static/(.*)', StaticHandler)],
- port)
+ server = WebApp(
+ [
+ (r"/ctm/(.*)", CTMHandler),
+ (r"/data/(.*)", DataHandler),
+ (r"/stim/(.*)", StimHandler),
+ (r"/mixer.html", MixerHandler),
+ (r"/picker", PickerHandler),
+ (r"/", MixerHandler),
+ (r"/static/(.*)", StaticHandler),
+ ],
+ port,
+ )
server.start()
- print("Started server on port %d"%server.port)
- url = "http://%s%s:%d/mixer.html"%(serve.hostname, domain_name, server.port)
+ print("Started server on port %d" % server.port)
+ url = "http://%s%s:%d/mixer.html" % (serve.hostname, domain_name, server.port)
if open_browser:
webbrowser.open(url)
client = server.get_client()
@@ -952,7 +1035,10 @@ def get_local_client(self):
elif display_url:
try:
from IPython.display import HTML, display
- display(HTML('Open viewer: {0}'.format(url)))
+
+ display(
+ HTML('Open viewer: {0}'.format(url))
+ )
except:
pass
diff --git a/examples/quickflat/plot_contours.py b/examples/quickflat/plot_contours.py
new file mode 100644
index 000000000..4eefe5931
--- /dev/null
+++ b/examples/quickflat/plot_contours.py
@@ -0,0 +1,108 @@
+"""
+===============================
+Plot parcellation contour lines
+===============================
+
+Parcellation contour lines can be overlaid on top of data to delineate
+region boundaries without obscuring the underlying activation map.
+
+This is useful when you want to show, for example, fMRI activation data
+with anatomical or functional parcellation borders drawn on top.
+
+The ``with_contours`` parameter accepts a :class:`cortex.Vertex` (or any
+Dataview) whose label boundaries will be drawn as contour lines. You can
+customise the line color with ``contour_linecolor`` and the line width
+with ``contour_linewidth``.
+"""
+
+import cortex
+import matplotlib.pyplot as plt
+import numpy as np
+from collections import deque
+
+np.random.seed(1234)
+
+subject = "S1"
+n_verts = cortex.db.get_surf(subject, "fiducial", merge=True)[0].shape[0]
+
+###############################################################################
+# Create a random parcellation
+# ----------------------------
+# We generate a parcellation by growing 30 random seed vertices across the
+# mesh using breadth-first search. Each seed becomes a parcel.
+
+n_parcels = 30
+_, polys = cortex.db.get_surf(subject, "fiducial", merge=True)
+neighbors = cortex.utils._get_neighbors_dict(polys)
+
+parcellation = np.zeros(n_verts, dtype=float)
+seeds = np.random.choice(n_verts, n_parcels, replace=False)
+for i, s in enumerate(seeds, 1):
+ parcellation[s] = float(i)
+
+queue = deque(seeds.tolist())
+while queue:
+ v = queue.popleft()
+ for nb in neighbors.get(v, []):
+ if nb < n_verts and parcellation[nb] == 0:
+ parcellation[nb] = parcellation[v]
+ queue.append(nb)
+
+# Create Vertex objects
+parc_vertex = cortex.Vertex(parcellation, subject, cmap="Set1", vmin=0, vmax=n_parcels)
+activation = cortex.Vertex(
+ np.random.randn(n_verts), subject, cmap="RdBu_r", vmin=-2, vmax=2
+)
+
+###############################################################################
+# Activation data with parcellation contours
+# -------------------------------------------
+# Pass a Dataview to ``with_contours`` to draw its label boundaries on top
+# of the primary data.
+
+fig = cortex.quickshow(
+ activation,
+ with_contours=parc_vertex,
+ with_curvature=True,
+ with_rois=False,
+ with_colorbar=True,
+ height=1024,
+)
+fig.suptitle("Activation + parcellation contours", fontsize=14)
+plt.show()
+
+###############################################################################
+# Custom contour color and width
+# ------------------------------
+# Use ``contour_linecolor`` (RGBA tuple) and ``contour_linewidth`` (pixels)
+# to customise the contour appearance.
+
+fig = cortex.quickshow(
+ activation,
+ with_contours=parc_vertex,
+ contour_linecolor=(1, 0, 0, 1),
+ contour_linewidth=3,
+ with_curvature=True,
+ with_rois=False,
+ with_colorbar=False,
+ height=1024,
+)
+fig.suptitle("Red thick contours", fontsize=14)
+plt.show()
+
+###############################################################################
+# Parcellation contours on curvature
+# -----------------------------------
+# You can also overlay the contours of the parcellation on its own data
+# to see both the filled colors and the borders.
+
+fig = cortex.quickshow(
+ parc_vertex,
+ with_contours=parc_vertex,
+ with_curvature=True,
+ with_rois=False,
+ with_colorbar=False,
+ height=1024,
+)
+fig.suptitle("Parcellation with contour borders", fontsize=14)
+plt.show()
diff --git a/examples/webgl/plot_contours_headless.py b/examples/webgl/plot_contours_headless.py
new file mode 100644
index 000000000..fbbd321de
--- /dev/null
+++ b/examples/webgl/plot_contours_headless.py
@@ -0,0 +1,107 @@
+"""
+===============================================
+Plot parcellation contours on 3D brain (headless)
+===============================================
+
+The WebGL viewer supports contour rendering of parcellation borders on
+the 3D cortical surface. ``cortex.export.save_3d_views`` accepts a
+``contour_overlay`` parameter — pass a ``cortex.Vertex`` with parcellation
+labels and the function automatically bundles it with the primary data,
+enabling contour borders in the rendered views.
+
+Available contour modes (``contour_mode``):
+
+- ``"contours"``: borders only on curvature
+- ``"contours+fill"``: data with solid-colour borders (default)
+- ``"colored"``: borders coloured by the overlay's colormap
+- ``"colored+fill"``: data with colormap-coloured borders
+
+Prerequisites
+-------------
+Install Playwright and download the bundled Chromium binary once::
+
+ pip install playwright
+ playwright install chromium
+
+"""
+
+import os
+import tempfile
+from collections import deque
+
+import matplotlib.pyplot as plt
+import numpy as np
+
+import cortex
+import cortex.export
+
+np.random.seed(1234)
+
+subject = "S1"
+n_verts = cortex.db.get_surf(subject, "fiducial", merge=True)[0].shape[0]
+
+###############################################################################
+# Create a random parcellation
+# ----------------------------
+# Grow 30 random seed vertices across the mesh using breadth-first search.
+
+n_parcels = 30
+_, polys = cortex.db.get_surf(subject, "fiducial", merge=True)
+neighbors = cortex.utils._get_neighbors_dict(polys)
+
+parcellation = np.zeros(n_verts, dtype=float)
+seeds = np.random.choice(n_verts, n_parcels, replace=False)
+for i, s in enumerate(seeds, 1):
+ parcellation[s] = float(i)
+
+queue = deque(seeds.tolist())
+while queue:
+ v = queue.popleft()
+ for nb in neighbors.get(v, []):
+ if nb < n_verts and parcellation[nb] == 0:
+ parcellation[nb] = parcellation[v]
+ queue.append(nb)
+
+###############################################################################
+# Create data and parcellation Vertex objects
+# --------------------------------------------
+
+activation = cortex.Vertex(
+ np.random.randn(n_verts), subject, cmap="RdBu_r", vmin=-2, vmax=2
+)
+parc_vertex = cortex.Vertex(parcellation, subject, cmap="Set1", vmin=0, vmax=n_parcels)
+
+###############################################################################
+# Render with parcellation contour overlay
+# ------------------------------------------
+# Pass the parcellation ``Vertex`` directly as ``contour_overlay``.
+# The function wraps both into a Dataset automatically.
+# The default ``contour_mode="contours+fill"`` draws black borders.
+
+base_name = os.path.join(tempfile.mkdtemp(), "contour")
+
+fnames = cortex.export.save_3d_views(
+ activation,
+ base_name=base_name,
+ list_angles=["left"],
+ list_surfaces=["inflated"],
+ viewer_params=dict(labels_visible=[], overlays_visible=[]),
+ size=(1920 * 2, 1080 * 2),
+ trim=True,
+ headless=True,
+ contour_overlay=parc_vertex,
+)
+
+for fname in fnames:
+ img = plt.imread(fname)
+ aspect = img.shape[0] / img.shape[1]
+ fig, ax = plt.subplots(figsize=(10, 10 * aspect))
+ ax.imshow(img)
+ ax.axis("off")
+ ax.set_title(
+ "Activation + parcellation contours (inflated, left)",
+ fontsize=14,
+ fontweight="bold",
+ )
+ fig.subplots_adjust(left=0, right=1, top=0.92, bottom=0)
+ plt.show()