-
**A package for 3D cell segmentation with deep learning, including a napari plugin**: training, inference, and data review. In particular, this project was developed for analysis of confocal and mesoSPIM-acquired (cleared tissue + lightsheet) tissue datasets, but is not limited to this type of data. [Check out our publication for more information!](https://elifesciences.org/articles/99848)
-

-
## Installation
- 💻 See the [Installation page](https://adaptivemotorcontrollab.github.io/CellSeg3D/welcome.html) in the documentation for detailed instructions.
+💻 See the [Installation page](https://adaptivemotorcontrollab.github.io/CellSeg3D/welcome.html) in the documentation for detailed instructions.
## Documentation
@@ -33,21 +32,22 @@ pip install napari_cellseg3d
```
To use the plugin, please run:
+
```
napari
```
+
Then go into `Plugins > napari_cellseg3d`, and choose which tool to use.
- **Review (label)**: This module allows you to review your labels, from predictions or manual labeling, and correct them if needed. It then saves the status of each file in a csv, for easier monitoring.
- **Inference**: This module allows you to use pre-trained segmentation algorithms on volumes to automatically label cells and compute statistics.
-- **Train**: This module allows you to train segmentation algorithms from labeled volumes.
+- **Train**: This module allows you to train segmentation algorithms from labeled volumes.
- **Utilities**: This module allows you to perform several actions like cropping your volumes and labels dynamically, by selecting a fixed size volume and moving it around the image; fragment images into smaller cubes for training; or converting labels from instance to segmentation and the opposite.
## Why use CellSeg3D?
The strength of our approach is we can match supervised model performance with purely self-supervised learning, meaning users don't need to spend (hundreds) of hours on annotation. Here is a quick look of our key results. TL;DR see panel **f**, which shows that with minmal input data we can outperform supervised models:
-

**Figure 1. Performance of 3D Semantic and Instance Segmentation Models.**
@@ -59,7 +59,6 @@ F1-score is computed from the Intersection over Union (IoU) with ground truth la
**c:** View of 3D instance labels from supervised models, as noted, for visual cortex volume in b evaluation.
**d:** Illustration of our WNet3D architecture showcasing the dual 3D U-Net structure with our modifications.
-
## News
### **CellSeg3D now published at eLife**
@@ -69,15 +68,19 @@ Read the [article here !](https://elifesciences.org/articles/99848)
### **New version: v0.2.2**
- v0.2.2:
+
- Updated the Colab Notebooks for training and inference
- New models available in the inference demo notebook
- CRF optional post-processing adjustments (and pip install directly)
+
- v0.2.1:
+
- Updated plugin default behaviors across the board to be more readily applicable to demo data
- Threshold value in inference is now automatically set by default according to performance on demo data on a per-model basis
- Added a grid search utility to find best thresholds for supervised models
- v0.2.0:
+
- Changed project name to "napari_cellseg3d" to avoid setuptools deprecation
- Small API changes for training/inference from a script
- Some fixes to WandB integration and csv saving after training
@@ -97,7 +100,6 @@ Previous additions:
- New utilities
- Many small improvements and many bug fixes
-
## Requirements
**Compatible with Python 3.8 to 3.10.**
@@ -118,39 +120,38 @@ Please reach out if you have any issues with the installation, we will be happy
To avoid issues when installing on the ARM64 architecture, please follow these steps.
-1) Create a new conda env using the provided conda/napari_CellSeg3D_ARM64.yml file :
+1. Create a new conda env using the provided conda/napari_CellSeg3D_ARM64.yml file :
- git clone https://github.com/AdaptiveMotorControlLab/CellSeg3d.git
- cd CellSeg3d
- conda env create -f conda/napari_CellSeg3D_ARM64.yml
- conda activate napari_CellSeg3D_ARM64
-
-
-2) Install a Qt backend (PySide or PyQt5)
-3) Launch napari, the plugin should be available in the plugins menu.
+ ```
+ git clone https://github.com/AdaptiveMotorControlLab/CellSeg3d.git
+ cd CellSeg3d
+ conda env create -f conda/napari_CellSeg3D_ARM64.yml
+ conda activate napari_CellSeg3D_ARM64
+ ```
+1. Install a Qt backend (PySide or PyQt5)
+1. Launch napari, the plugin should be available in the plugins menu.
## Issues
**Help us make the code better by reporting issues and adding your feature requests!**
-
If you encounter any problems, please [file an issue] along with a detailed description.
## Testing
-You can generate docs locally by running ``make html`` in the docs/ folder.
+You can generate docs locally by running `make html` in the docs/ folder.
-Before testing, install all requirements using ``pip install napari-cellseg3d[test]``.
+Before testing, install all requirements using `pip install napari-cellseg3d[test]`.
-``pydensecrf`` is also required for testing.
+`pydensecrf` is also required for testing.
To run tests locally:
-- Locally : run ``pytest napari_cellseg3d\_tests`` in the plugin folder.
-- Locally with coverage : In the plugin folder, run ``coverage run --source=napari_cellseg3d -m pytest`` then ``coverage xml`` to generate a .xml coverage file.
-- With tox : run ``tox`` in the plugin folder (will simulate tests with several python and OS configs, requires substantial storage space)
+- Locally : run `pytest napari_cellseg3d\_tests` in the plugin folder.
+- Locally with coverage : In the plugin folder, run `coverage run --source=napari_cellseg3d -m pytest` then `coverage xml` to generate a .xml coverage file.
+- With tox : run `tox` in the plugin folder (will simulate tests with several python and OS configs, requires substantial storage space)
## Contributing
@@ -170,23 +171,6 @@ Distributed under the terms of the [MIT] license.
"napari-cellseg3d" is free and open source software.
-[napari-hub]: https://www.napari-hub.org/plugins/napari-cellseg3d
-
-[file an issue]: https://github.com/AdaptiveMotorControlLab/CellSeg3D/issues
-[napari]: https://github.com/napari/napari
-[Cookiecutter]: https://github.com/audreyr/cookiecutter
-[@napari]: https://github.com/napari
-[MIT]: http://opensource.org/licenses/MIT
-[cookiecutter-napari-plugin]: https://github.com/napari/cookiecutter-napari-plugin
-[tox]: https://tox.readthedocs.io/en/latest/
-[pip]: https://pypi.org/project/pip/
-[PyPI]: https://pypi.org/
-[Installation page]: https://adaptivemotorcontrollab.github.io/CellSeg3D/source/guides/installation_guide.html
-[the PyTorch website for installation instructions]: https://pytorch.org/get-started/locally/
-[PyTorch]: https://pytorch.org/get-started/locally/
-[MONAI's optional dependencies]: https://docs.monai.io/en/stable/installation.html#installing-the-recommended-dependencies
-[MONAI]: https://docs.monai.io/en/stable/installation.html#installing-the-recommended-dependencies
-
## Citation
```
@@ -220,3 +204,14 @@ Please refer to the documentation for full acknowledgements.
## Plugin base
This [napari] plugin was generated with [Cookiecutter] using [@napari]'s [cookiecutter-napari-plugin] template.
+
+[@napari]: https://github.com/napari
+[cookiecutter]: https://github.com/audreyr/cookiecutter
+[cookiecutter-napari-plugin]: https://github.com/napari/cookiecutter-napari-plugin
+[file an issue]: https://github.com/AdaptiveMotorControlLab/CellSeg3D/issues
+[mit]: http://opensource.org/licenses/MIT
+[monai]: https://docs.monai.io/en/stable/installation.html#installing-the-recommended-dependencies
+[monai's optional dependencies]: https://docs.monai.io/en/stable/installation.html#installing-the-recommended-dependencies
+[napari]: https://github.com/napari/napari
+[pytorch]: https://pytorch.org/get-started/locally/
+[the pytorch website for installation instructions]: https://pytorch.org/get-started/locally/
diff --git a/docs/TODO.md b/docs/TODO.md
index 88aafee7..f2906380 100644
--- a/docs/TODO.md
+++ b/docs/TODO.md
@@ -1,8 +1,9 @@
-[//]: # (
+\[//\]: # (
TODO:
+
- [ ] Add a way to get the current version of the library
- [x] Update all modules
- [x] Better WNet3D tutorial
- [x] Setup GH Actions
- [ ] Add a bibliography
-)
+ )
diff --git a/docs/source/guides/training_wnet.rst b/docs/source/guides/training_wnet.rst
index 359dc321..429f0d59 100644
--- a/docs/source/guides/training_wnet.rst
+++ b/docs/source/guides/training_wnet.rst
@@ -18,7 +18,7 @@ The WNet3D **does not require a large amount of data to train**, but **choosing
You may find below some guidelines, based on our own data and testing.
-The WNet3D is a self-supervised learning approach for 3D cell segmentation, and relies on the assumption that structural and morphological features of cells can be inferred directly from unlabeled data. This involves leveraging inherent properties such as spatial coherence and local contrast in imaging volumes to distinguish cellular structures. This approach assumes that meaningful representations of cellular boundaries and nuclei can emerge solely from raw 3D volumes. Thus, we strongly recommend that you use WNet3D on stacks that have clear foreground/background segregation and limited noise. Even if your final samples have noise, it is best to train on data that is as clean as you can.
+The WNet3D is a self-supervised learning approach for 3D cell segmentation, and relies on the assumption that structural and morphological features of cells can be inferred directly from unlabeled data. This involves leveraging inherent properties such as spatial coherence and local contrast in imaging volumes to distinguish cellular structures. This approach assumes that meaningful representations of cellular boundaries and nuclei can emerge solely from raw 3D volumes. Thus, we strongly recommend that you use WNet3D on stacks that have clear foreground/background segregation and limited noise. Even if your final samples have noise, it is best to train on data that is as clean as you can.
.. important::
diff --git a/examples/README.md b/examples/README.md
index be6ea1c4..ab67956e 100644
--- a/examples/README.md
+++ b/examples/README.md
@@ -5,4 +5,4 @@ All credits to the original authors of the data.
You can install, launch `napari`, activate the CellSeg3D plugin app, and drag & drop this volume into the canvas.
Then, for example, run the `Inference` module with one of our models.
-See [CellSeg3D documentation](https://adaptivemotorcontrollab.github.io/CellSeg3D/welcome.html) for more details.
\ No newline at end of file
+See [CellSeg3D documentation](https://adaptivemotorcontrollab.github.io/CellSeg3D/welcome.html) for more details.
diff --git a/napari_cellseg3d/_tests/conftest.py b/napari_cellseg3d/_tests/conftest.py
deleted file mode 100644
index bbfeff10..00000000
--- a/napari_cellseg3d/_tests/conftest.py
+++ /dev/null
@@ -1,18 +0,0 @@
-import os
-
-import pytest
-
-
-@pytest.fixture(scope="session", autouse=True)
-def env_config():
- """
- Configure environment variables needed for the test session
- """
-
- # This makes QT render everything offscreen and thus prevents
- # any Modals / Dialogs or other Widgets being rendered on the screen while running unit tests
- os.environ["QT_QPA_PLATFORM"] = "offscreen"
-
- yield
-
- os.environ.pop("QT_QPA_PLATFORM")
diff --git a/napari_cellseg3d/_tests/pytest.ini b/napari_cellseg3d/_tests/pytest.ini
deleted file mode 100644
index 45c3be1c..00000000
--- a/napari_cellseg3d/_tests/pytest.ini
+++ /dev/null
@@ -1,2 +0,0 @@
-[pytest]
-qt_api=pyqt5
diff --git a/napari_cellseg3d/_tests/test_inference.py b/napari_cellseg3d/_tests/test_inference.py
index 62972ba1..4bdc51fa 100644
--- a/napari_cellseg3d/_tests/test_inference.py
+++ b/napari_cellseg3d/_tests/test_inference.py
@@ -63,9 +63,7 @@ def test_load_folder():
def test_inference_on_folder():
config = InferenceWorkerConfig()
config.filetype = ".tif"
- config.images_filepaths = [
- str(Path(__file__).resolve().parent / "res/test.tif")
- ]
+ config.images_filepaths = [str(Path(__file__).resolve().parent / "res/test.tif")]
config.sliding_window_config.window_size = 8
diff --git a/napari_cellseg3d/_tests/test_labels_correction.py b/napari_cellseg3d/_tests/test_labels_correction.py
index b4f13238..c87e2067 100644
--- a/napari_cellseg3d/_tests/test_labels_correction.py
+++ b/napari_cellseg3d/_tests/test_labels_correction.py
@@ -32,9 +32,7 @@ def test_artefact_labeling_utils():
def test_correct_labels():
output_path = res_folder / "test_correct"
output_path.mkdir(exist_ok=True, parents=True)
- cl.relabel_non_unique_i(
- labels, str(output_path / "corrected.tif"), go_fast=True
- )
+ cl.relabel_non_unique_i(labels, str(output_path / "corrected.tif"), go_fast=True)
def test_relabel():
@@ -47,6 +45,4 @@ def test_relabel():
def test_evaluate_model_performance():
- el.evaluate_model_performance(
- labels, labels, print_details=True, visualize=False
- )
+ el.evaluate_model_performance(labels, labels, print_details=True, visualize=False)
diff --git a/napari_cellseg3d/_tests/test_model_framework.py b/napari_cellseg3d/_tests/test_model_framework.py
index 1cb86569..475c334d 100644
--- a/napari_cellseg3d/_tests/test_model_framework.py
+++ b/napari_cellseg3d/_tests/test_model_framework.py
@@ -58,9 +58,7 @@ def test_create_train_dataset_dict(make_napari_viewer_proxy):
def test_log(make_napari_viewer_proxy):
mock_test = "test"
- framework = model_framework.ModelFramework(
- viewer=make_napari_viewer_proxy()
- )
+ framework = model_framework.ModelFramework(viewer=make_napari_viewer_proxy())
framework.log.print_and_log(mock_test)
assert len(framework.log.toPlainText()) != 0
assert framework.log.toPlainText() == "\n" + mock_test
@@ -83,9 +81,7 @@ def test_log(make_napari_viewer_proxy):
def test_display_elements(make_napari_viewer_proxy):
- framework = model_framework.ModelFramework(
- viewer=make_napari_viewer_proxy()
- )
+ framework = model_framework.ModelFramework(viewer=make_napari_viewer_proxy())
framework.display_status_report()
framework.display_status_report()
@@ -96,20 +92,13 @@ def test_display_elements(make_napari_viewer_proxy):
def test_available_models_retrieval(make_napari_viewer_proxy):
- framework = model_framework.ModelFramework(
- viewer=make_napari_viewer_proxy()
- )
+ framework = model_framework.ModelFramework(viewer=make_napari_viewer_proxy())
assert framework.get_available_models() == MODEL_LIST
def test_update_weights_path(make_napari_viewer_proxy):
- framework = model_framework.ModelFramework(
- viewer=make_napari_viewer_proxy()
- )
- assert (
- framework._update_weights_path(framework._default_weights_folder)
- is None
- )
+ framework = model_framework.ModelFramework(viewer=make_napari_viewer_proxy())
+ assert framework._update_weights_path(framework._default_weights_folder) is None
name = str(Path.home() / "test/weight.pth")
framework._update_weights_path([name])
assert framework.weights_config.path == name
diff --git a/napari_cellseg3d/_tests/test_plugin_training.py b/napari_cellseg3d/_tests/test_plugin_training.py
index 32edd0e0..ff56f0f7 100644
--- a/napari_cellseg3d/_tests/test_plugin_training.py
+++ b/napari_cellseg3d/_tests/test_plugin_training.py
@@ -32,9 +32,7 @@ def test_worker_configs(make_napari_viewer_proxy):
]
for attr in dir(default_config):
if not attr.startswith("__") and attr not in excluded:
- assert getattr(default_config, attr) == getattr(
- worker.config, attr
- )
+ assert getattr(default_config, attr) == getattr(worker.config, attr)
# test unsupervised config and worker
widget.model_choice.setCurrentText("WNet3D")
widget._toggle_unsupervised_mode(enabled=True)
@@ -43,9 +41,7 @@ def test_worker_configs(make_napari_viewer_proxy):
excluded = ["results_path_folder", "sample_size", "weights_info"]
for attr in dir(default_config):
if not attr.startswith("__") and attr not in excluded:
- assert getattr(default_config, attr) == getattr(
- worker.config, attr
- )
+ assert getattr(default_config, attr) == getattr(worker.config, attr)
widget.unsupervised_images_filewidget.text_field.setText(
str((im_path.parent / "wnet_test").resolve())
)
diff --git a/napari_cellseg3d/_tests/test_plugin_utils.py b/napari_cellseg3d/_tests/test_plugin_utils.py
index 1e0e01c2..4aa74bf1 100644
--- a/napari_cellseg3d/_tests/test_plugin_utils.py
+++ b/napari_cellseg3d/_tests/test_plugin_utils.py
@@ -21,9 +21,7 @@ def test_utils_plugin(make_napari_viewer_proxy):
view.dims.ndisplay = 3
for i, utils_name in enumerate(UTILITIES_WIDGETS.keys()):
widget.utils_choice.setCurrentIndex(i)
- assert isinstance(
- widget.utils_widgets[i], UTILITIES_WIDGETS[utils_name]
- )
+ assert isinstance(widget.utils_widgets[i], UTILITIES_WIDGETS[utils_name])
if utils_name == "Convert to instance labels":
# to avoid issues with Voronoi-Otsu missing runtime
menu = widget.utils_widgets[i].instance_widgets.method_choice
diff --git a/napari_cellseg3d/_tests/test_review.py b/napari_cellseg3d/_tests/test_review.py
index 98f4f682..52209bed 100644
--- a/napari_cellseg3d/_tests/test_review.py
+++ b/napari_cellseg3d/_tests/test_review.py
@@ -15,9 +15,7 @@ def test_launch_review(make_napari_viewer_proxy):
widget.folder_choice.setChecked(True)
widget.image_filewidget.text_field = im_path
widget.labels_filewidget.text_field = lab_path
- widget.results_filewidget.text_field = str(
- Path(__file__).resolve().parent / "res"
- )
+ widget.results_filewidget.text_field = str(Path(__file__).resolve().parent / "res")
widget.run_review()
widget._viewer.close()
diff --git a/napari_cellseg3d/_tests/test_training.py b/napari_cellseg3d/_tests/test_training.py
index 15ec119e..b5277c5b 100644
--- a/napari_cellseg3d/_tests/test_training.py
+++ b/napari_cellseg3d/_tests/test_training.py
@@ -52,12 +52,8 @@ def test_supervised_training(make_napari_viewer_proxy):
worker_config = widget._set_worker_config()
assert worker_config.model_info.name == "test"
worker = widget._create_supervised_worker_from_config(worker_config)
- worker.config.train_data_dict = [
- {"image": im_path_str, "label": im_path_str}
- ]
- worker.config.val_data_dict = [
- {"image": im_path_str, "label": im_path_str}
- ]
+ worker.config.train_data_dict = [{"image": im_path_str, "label": im_path_str}]
+ worker.config.val_data_dict = [{"image": im_path_str, "label": im_path_str}]
worker.config.max_epochs = 2
worker.config.validation_interval = 2
@@ -97,9 +93,7 @@ def test_unsupervised_training(make_napari_viewer_proxy):
)
# widget.start()
widget.data = widget.create_dataset_dict_no_labs()
- widget.worker = widget._create_worker(
- additional_results_description="wnet_test"
- )
+ widget.worker = widget._create_worker(additional_results_description="wnet_test")
assert widget.worker.config.train_data_dict is not None
widget.worker.config.max_epochs = 1
for res_i in widget.worker.train(
diff --git a/napari_cellseg3d/_tests/test_utils.py b/napari_cellseg3d/_tests/test_utils.py
index 57f7ec7f..57d60bbb 100644
--- a/napari_cellseg3d/_tests/test_utils.py
+++ b/napari_cellseg3d/_tests/test_utils.py
@@ -29,9 +29,7 @@ def test_save_folder():
images = [rand_gen.random((5, 5, 5)).astype(np.float32) for _ in range(10)]
images_paths = [f"{i}.tif" for i in range(10)]
- utils.save_folder(
- test_path, folder_name, images, images_paths, exist_ok=True
- )
+ utils.save_folder(test_path, folder_name, images, images_paths, exist_ok=True)
assert (test_path / folder_name).is_dir()
for i in range(10):
assert (test_path / folder_name / images_paths[i]).is_file()
@@ -51,9 +49,7 @@ def test_sphericities():
mock_surface = random.randint(
100, 1000
) # assuming surface is always larger than volume
- sphericity_vol = utils.sphericity_volume_area(
- mock_volume, mock_surface
- )
+ sphericity_vol = utils.sphericity_volume_area(mock_volume, mock_surface)
assert 0 <= sphericity_vol <= 1
semi_major = random.randint(10, 100)
@@ -65,9 +61,7 @@ def test_sphericities():
except ValueError:
sphericity_axes = 0
if sphericity_axes is None:
- sphericity_axes = (
- 0 # errors already handled in function, returns None
- )
+ sphericity_axes = 0 # errors already handled in function, returns None
assert 0 <= sphericity_axes <= 1
@@ -215,15 +209,11 @@ def test_parse_default_path():
test_path = (Path.home() / "test" / "test" / "test" / "test").as_posix()
path = [test_path, None, None]
- assert utils.parse_default_path(path, check_existence=False) == str(
- test_path
- )
+ assert utils.parse_default_path(path, check_existence=False) == str(test_path)
test_path = (Path.home() / "test" / "does" / "not" / "exist").as_posix()
path = [test_path, None, None]
- assert utils.parse_default_path(path, check_existence=True) == str(
- Path.home()
- )
+ assert utils.parse_default_path(path, check_existence=True) == str(Path.home())
long_path = Path.home()
long_path = (
diff --git a/napari_cellseg3d/_tests/test_weight_download.py b/napari_cellseg3d/_tests/test_weight_download.py
index e1392436..464785b4 100644
--- a/napari_cellseg3d/_tests/test_weight_download.py
+++ b/napari_cellseg3d/_tests/test_weight_download.py
@@ -1,10 +1,17 @@
+import os
+
+import pytest
+
from napari_cellseg3d.code_models.workers_utils import (
PRETRAINED_WEIGHTS_DIR,
WeightsDownloader,
)
-# DISABLED, causes GitHub actions to freeze
+@pytest.mark.skipif(
+ os.getenv("GITHUB_ACTIONS") == "true",
+ reason="This test causes GitHub Actions to freeze",
+)
def test_weight_download():
downloader = WeightsDownloader()
downloader.download_weights("test", "test.pth")
diff --git a/napari_cellseg3d/code_models/crf.py b/napari_cellseg3d/code_models/crf.py
index 14a8a373..928cc8f6 100644
--- a/napari_cellseg3d/code_models/crf.py
+++ b/napari_cellseg3d/code_models/crf.py
@@ -38,6 +38,7 @@
unary_from_softmax,
)
+
def correct_shape_for_crf(image, desired_dims=4):
"""Corrects the shape of the image to be compatible with the CRF post-processing step."""
logger.debug(f"Correcting shape for CRF, desired_dims={desired_dims}")
@@ -104,9 +105,7 @@ def crf(image, prob, sa, sb, sg, w1, w2, n_iter=5):
)
return None
- d = dcrf.DenseCRF(
- image.shape[1] * image.shape[2] * image.shape[3], prob.shape[0]
- )
+ d = dcrf.DenseCRF(image.shape[1] * image.shape[2] * image.shape[3], prob.shape[0])
# Get unary potentials from softmax probabilities
U = unary_from_softmax(prob)
diff --git a/napari_cellseg3d/code_models/instance_segmentation.py b/napari_cellseg3d/code_models/instance_segmentation.py
index 754a996f..d5efec29 100644
--- a/napari_cellseg3d/code_models/instance_segmentation.py
+++ b/napari_cellseg3d/code_models/instance_segmentation.py
@@ -59,9 +59,7 @@ def __init__(
self.function = function
self.counters: List[ui.DoubleIncrementCounter] = []
self.sliders: List[ui.Slider] = []
- self._setup_widgets(
- num_counters, num_sliders, widget_parent=widget_parent
- )
+ self._setup_widgets(num_counters, num_sliders, widget_parent=widget_parent)
self.recorded_parameters = {}
"""Stores the parameters when calling self.record_parameters()"""
@@ -97,9 +95,7 @@ def _setup_widgets(self, num_counters, num_sliders, widget_parent=None):
setattr(
self,
widget,
- ui.DoubleIncrementCounter(
- text_label="", parent=widget_parent
- ),
+ ui.DoubleIncrementCounter(text_label="", parent=widget_parent),
)
self.counters.append(getattr(self, widget))
@@ -129,14 +125,10 @@ def record_parameters(self):
"""Records all the parameters of the instance segmentation method from the current values of the widgets."""
if len(self.sliders) > 0:
for slider in self.sliders:
- self.recorded_parameters[slider.label.text()] = (
- slider.slider_value
- )
+ self.recorded_parameters[slider.label.text()] = slider.slider_value
if len(self.counters) > 0:
for counter in self.counters:
- self.recorded_parameters[counter.label.text()] = (
- counter.value()
- )
+ self.recorded_parameters[counter.label.text()] = counter.value()
def run_method_from_params(self, image):
"""Runs the method on the image with the RECORDED parameters set in the widget.
@@ -149,14 +141,10 @@ def run_method_from_params(self, image):
Returns: processed image from self._method
"""
if len(self.recorded_parameters) == 0:
- logger.warning(
- "No parameters recorded, running with values from widgets"
- )
+ logger.warning("No parameters recorded, running with values from widgets")
self.record_parameters()
- parameters = [
- self.recorded_parameters[key] for key in self.recorded_parameters
- ]
+ parameters = [self.recorded_parameters[key] for key in self.recorded_parameters]
assert len(parameters) == len(self.sliders) + len(self.counters), (
f"Number of parameters recorded ({len(parameters)}) "
@@ -187,9 +175,7 @@ def run_method_on_channels_from_params(self, image):
Returns: processed image from self._method
"""
image_list = self._make_list_from_channels(image)
- result = np.array(
- [self.run_method_from_params(im) for im in image_list]
- )
+ result = np.array([self.run_method_from_params(im) for im in image_list])
return result.squeeze()
@staticmethod
@@ -389,9 +375,7 @@ def clear_large_objects(image, large_label_size=200, use_window=True):
thres_small=large_label_size,
rem_seed_thres=0,
)
- res = InstanceMethod.sliding_window(
- image, func, increment_labels=False
- )
+ res = InstanceMethod.sliding_window(image, func, increment_labels=False)
return np.where(res > 0, 0, image)
labeled = binary_watershed(
@@ -545,9 +529,7 @@ def __init__(self, widget_parent=None):
)
self.sliders[0].label.setText("Foreground probability threshold")
- self.sliders[0].tooltips = (
- "Probability threshold for foreground object"
- )
+ self.sliders[0].tooltips = "Probability threshold for foreground object"
self.sliders[0].setValue(500)
self.sliders[1].label.setText("Seed probability threshold")
@@ -646,9 +628,7 @@ def __init__(self, widget_parent=None):
)
self.sliders[0].label.setText("Foreground probability threshold")
- self.sliders[0].tooltips = (
- "Probability threshold for foreground object"
- )
+ self.sliders[0].tooltips = "Probability threshold for foreground object"
self.sliders[0].setValue(800)
self.counters[0].label.setText("Small objects removal")
@@ -709,16 +689,12 @@ def __init__(self, widget_parent=None):
widget_parent=widget_parent,
)
self.counters[0].label.setText("Spot sigma") # closeness
- self.counters[0].tooltips = (
- "Determines how close detected objects can be"
- )
+ self.counters[0].tooltips = "Determines how close detected objects can be"
self.counters[0].setMaximum(100)
self.counters[0].setValue(0.65)
self.counters[1].label.setText("Outline sigma") # smoothness
- self.counters[1].tooltips = (
- "Determines the smoothness of the segmentation"
- )
+ self.counters[1].tooltips = "Determines the smoothness of the segmentation"
self.counters[1].setMaximum(100)
self.counters[1].setValue(0.65)
@@ -823,9 +799,7 @@ def _build(self):
group.layout.addWidget(counter)
self.instance_widgets[name].append(counter)
except RuntimeError as e:
- logger.debug(
- f"Caught runtime error {e}, most likely during testing"
- )
+ logger.debug(f"Caught runtime error {e}, most likely during testing")
self.setLayout(group.layout)
self._set_visibility()
diff --git a/napari_cellseg3d/code_models/model_framework.py b/napari_cellseg3d/code_models/model_framework.py
index 7caff7b6..fadf9c40 100644
--- a/napari_cellseg3d/code_models/model_framework.py
+++ b/napari_cellseg3d/code_models/model_framework.py
@@ -1,4 +1,5 @@
"""Basic napari plugin framework for inference and training."""
+
from pathlib import Path
from typing import TYPE_CHECKING
@@ -48,9 +49,7 @@ def __init__(
loads_labels: if True, will contain UI elements used to load napari label layers
has_results: if True, will add UI to choose a results path
"""
- super().__init__(
- viewer, parent, loads_images, loads_labels, has_results
- )
+ super().__init__(viewer, parent, loads_images, loads_labels, has_results)
self._viewer = viewer
"""Viewer to display the widget in"""
@@ -107,9 +106,7 @@ def __init__(
self.report_container = ui.ContainerWidget(l=10, t=5, r=5, b=5)
- self.report_container.setSizePolicy(
- QSizePolicy.Fixed, QSizePolicy.Minimum
- )
+ self.report_container.setSizePolicy(QSizePolicy.Fixed, QSizePolicy.Minimum)
self.container_docked = False # check if already docked
self.progress = QProgressBar(self.report_container)
@@ -240,9 +237,7 @@ def display_status_report(self):
def _toggle_weights_path(self):
"""Toggle visibility of weight path."""
- ui.toggle_visibility(
- self.custom_weights_choice, self.weights_filewidget
- )
+ ui.toggle_visibility(self.custom_weights_choice, self.weights_filewidget)
def get_unsupervised_image_filepaths(self):
"""Returns a list of filepaths to images in the unsupervised images folder."""
diff --git a/napari_cellseg3d/code_models/models/TEMPLATE_model.py b/napari_cellseg3d/code_models/models/TEMPLATE_model.py
index 7c33adf4..0cd08141 100644
--- a/napari_cellseg3d/code_models/models/TEMPLATE_model.py
+++ b/napari_cellseg3d/code_models/models/TEMPLATE_model.py
@@ -15,9 +15,7 @@ class ModelTemplate_(ABC):
default_threshold = 0.5 # specify the default threshold for the model
@abstractmethod
- def __init__(
- self, input_image_size, in_channels=1, out_channels=1, **kwargs
- ):
+ def __init__(self, input_image_size, in_channels=1, out_channels=1, **kwargs):
"""Reimplement this as needed; only include input_image_size if necessary. For now only in/out channels = 1 is supported."""
pass
diff --git a/napari_cellseg3d/code_models/models/model_SegResNet.py b/napari_cellseg3d/code_models/models/model_SegResNet.py
index ef1e7492..4e447a01 100644
--- a/napari_cellseg3d/code_models/models/model_SegResNet.py
+++ b/napari_cellseg3d/code_models/models/model_SegResNet.py
@@ -9,9 +9,7 @@ class SegResNet_(SegResNetVAE):
weights_file = "SegResNet_latest.pth"
default_threshold = 0.3
- def __init__(
- self, input_img_size, out_channels=1, dropout_prob=0.3, **kwargs
- ):
+ def __init__(self, input_img_size, out_channels=1, dropout_prob=0.3, **kwargs):
"""Create a SegResNet model.
Args:
diff --git a/napari_cellseg3d/code_models/models/model_TRAILMAP.py b/napari_cellseg3d/code_models/models/model_TRAILMAP.py
index 6673d1d1..1389b1a1 100644
--- a/napari_cellseg3d/code_models/models/model_TRAILMAP.py
+++ b/napari_cellseg3d/code_models/models/model_TRAILMAP.py
@@ -1,4 +1,5 @@
"""Legacy version of adapted TRAILMAP model, not used in the current version of the plugin."""
+
# import torch
# from torch import nn
#
diff --git a/napari_cellseg3d/code_models/models/model_TRAILMAP_MS.py b/napari_cellseg3d/code_models/models/model_TRAILMAP_MS.py
index 2735c871..e95bcaef 100644
--- a/napari_cellseg3d/code_models/models/model_TRAILMAP_MS.py
+++ b/napari_cellseg3d/code_models/models/model_TRAILMAP_MS.py
@@ -26,9 +26,7 @@ def __init__(self, in_channels=1, out_channels=1, **kwargs):
)
except TypeError as e:
logger.warning(f"Caught TypeError: {e}")
- super().__init__(
- in_channels=in_channels, out_channels=out_channels
- )
+ super().__init__(in_channels=in_channels, out_channels=out_channels)
# def get_output(self, input):
# out = self(input)
diff --git a/napari_cellseg3d/code_models/models/unet/buildingblocks.py b/napari_cellseg3d/code_models/models/unet/buildingblocks.py
index ce7d378f..d9c68e47 100644
--- a/napari_cellseg3d/code_models/models/unet/buildingblocks.py
+++ b/napari_cellseg3d/code_models/models/unet/buildingblocks.py
@@ -6,16 +6,11 @@
def conv3d(in_channels, out_channels, kernel_size, bias, padding):
- return nn.Conv3d(
- in_channels, out_channels, kernel_size, padding=padding, bias=bias
- )
+ return nn.Conv3d(in_channels, out_channels, kernel_size, padding=padding, bias=bias)
-def create_conv(
- in_channels, out_channels, kernel_size, order, num_groups, padding
-):
- """
- Create a list of modules with together constitute a single conv layer with non-linearity
+def create_conv(in_channels, out_channels, kernel_size, order, num_groups, padding):
+ """Create a list of modules with together constitute a single conv layer with non-linearity
and optional batchnorm/groupnorm.
Args:
@@ -35,9 +30,9 @@ def create_conv(
list of tuple (name, module)
"""
assert "c" in order, "Conv layer MUST be present"
- assert (
- order[0] not in "rle"
- ), "Non-linearity cannot be the first operation in the layer"
+ assert order[0] not in "rle", (
+ "Non-linearity cannot be the first operation in the layer"
+ )
modules = []
for i, char in enumerate(order):
@@ -70,15 +65,13 @@ def create_conv(
if num_channels < num_groups:
num_groups = 1
- assert (
- num_channels % num_groups == 0
- ), f"Expected number of channels in input to be divisible by num_groups. num_channels={num_channels}, num_groups={num_groups}"
+ assert num_channels % num_groups == 0, (
+ f"Expected number of channels in input to be divisible by num_groups. num_channels={num_channels}, num_groups={num_groups}"
+ )
modules.append(
(
"groupnorm",
- nn.GroupNorm(
- num_groups=num_groups, num_channels=num_channels
- ),
+ nn.GroupNorm(num_groups=num_groups, num_channels=num_channels),
)
)
elif char == "b":
@@ -96,8 +89,7 @@ def create_conv(
class SingleConv(nn.Sequential):
- """
- Basic convolutional module consisting of a Conv3d, non-linearity and optional batchnorm/groupnorm. The order
+ """Basic convolutional module consisting of a Conv3d, non-linearity and optional batchnorm/groupnorm. The order
of operations can be specified via the `order` parameter
Args:
@@ -136,8 +128,7 @@ def __init__(
class DoubleConv(nn.Sequential):
- """
- A module consisting of two consecutive convolution layers (e.g. BatchNorm3d+ReLU+Conv3d).
+ """A module consisting of two consecutive convolution layers (e.g. BatchNorm3d+ReLU+Conv3d).
We use (Conv3d+ReLU+GroupNorm3d) by default.
This can be changed however by providing the 'order' argument, e.g. in order
to change to Conv3d+BatchNorm3d+ELU use order='cbe'.
@@ -211,8 +202,7 @@ def __init__(
class ExtResNetBlock(nn.Module):
- """
- Basic UNet block consisting of a SingleConv followed by the residual block.
+ """Basic UNet block consisting of a SingleConv followed by the residual block.
The SingleConv takes care of increasing/decreasing the number of channels and also ensures that the number
of output channels is compatible with the residual block that follows.
This block can be used instead of standard DoubleConv in the Encoder module.
@@ -278,18 +268,16 @@ def forward(self, x):
out = self.conv3(out)
out += residual
- out = self.non_linearity(out)
-
- return out
+ return self.non_linearity(out)
class Encoder(nn.Module):
- """
- A single module from the encoder path consisting of the optional max
+ """A single module from the encoder path consisting of the optional max
pooling layer (one may specify the MaxPool kernel_size to be different
than the standard (2,2,2), e.g. if the volumetric data is anisotropic
(make sure to use complementary scale_factor in the decoder path) followed by
a DoubleConv module.
+
Args:
in_channels (int): number of input channels
out_channels (int): number of output channels
@@ -340,14 +328,13 @@ def __init__(
def forward(self, x):
if self.pooling is not None:
x = self.pooling(x)
- x = self.basic_module(x)
- return x
+ return self.basic_module(x)
class Decoder(nn.Module):
- """
- A single module for decoder path consisting of the upsampling layer
+ """A single module for decoder path consisting of the upsampling layer
(either learned ConvTranspose3d or nearest neighbor interpolation) followed by a basic module (DoubleConv or ExtResNetBlock).
+
Args:
in_channels (int): number of input channels
out_channels (int): number of output channels
@@ -415,8 +402,7 @@ def __init__(
def forward(self, encoder_features, x):
x = self.upsampling(encoder_features=encoder_features, x=x)
x = self.joining(encoder_features, x)
- x = self.basic_module(x)
- return x
+ return self.basic_module(x)
@staticmethod
def _joining(encoder_features, x, concat):
@@ -510,8 +496,7 @@ def create_decoders(
class AbstractUpsampling(nn.Module):
- """
- Abstract class for upsampling. A given implementation should upsample a given 5D input tensor using either
+ """Abstract class for upsampling. A given implementation should upsample a given 5D input tensor using either
interpolation or learned transposed convolution.
"""
@@ -527,11 +512,10 @@ def forward(self, encoder_features, x):
class InterpolateUpsampling(AbstractUpsampling):
- """
- Args:
- mode (str): algorithm used for upsampling:
- 'nearest' | 'linear' | 'bilinear' | 'trilinear' | 'area'. Default: 'nearest'
- used only if transposed_conv is False
+ """Args:
+ mode (str): algorithm used for upsampling:
+ 'nearest' | 'linear' | 'bilinear' | 'trilinear' | 'area'. Default: 'nearest'
+ used only if transposed_conv is False
"""
def __init__(self, mode="nearest"):
@@ -544,16 +528,15 @@ def _interpolate(x, size, mode):
class TransposeConvUpsampling(AbstractUpsampling):
- """
- Args:
- in_channels (int): number of input channels for transposed conv
- used only if transposed_conv is True
- out_channels (int): number of output channels for transpose conv
- used only if transposed_conv is True
- kernel_size (int or tuple): size of the convolving kernel
- used only if transposed_conv is True
- scale_factor (int or tuple): stride of the convolution
- used only if transposed_conv is True
+ """Args:
+ in_channels (int): number of input channels for transposed conv
+ used only if transposed_conv is True
+ out_channels (int): number of output channels for transpose conv
+ used only if transposed_conv is True
+ kernel_size (int or tuple): size of the convolving kernel
+ used only if transposed_conv is True
+ scale_factor (int or tuple): stride of the convolution
+ used only if transposed_conv is True
"""
diff --git a/napari_cellseg3d/code_models/models/unet/model.py b/napari_cellseg3d/code_models/models/unet/model.py
index 9591d054..d4d4deb3 100644
--- a/napari_cellseg3d/code_models/models/unet/model.py
+++ b/napari_cellseg3d/code_models/models/unet/model.py
@@ -12,8 +12,7 @@ def number_of_features_per_level(init_channel_number, num_levels):
class Abstract3DUNet(nn.Module):
- """
- Base class for standard and residual UNet.
+ """Base class for standard and residual UNet.
Args:
in_channels (int): number of input channels
@@ -60,9 +59,7 @@ def __init__(
super(Abstract3DUNet, self).__init__()
if isinstance(f_maps, int):
- f_maps = number_of_features_per_level(
- f_maps, num_levels=num_levels
- )
+ f_maps = number_of_features_per_level(f_maps, num_levels=num_levels)
assert isinstance(f_maps, (list, tuple))
assert len(f_maps) > 1, "Required at least 2 levels in the U-Net"
@@ -132,8 +129,7 @@ def forward(self, x):
class UNet3D(Abstract3DUNet):
- """
- 3DUnet model from
+ """3DUnet model from
`"3D U-Net: Learning Dense Volumetric Segmentation from Sparse Annotation"
update_plot
def update_plot(z)
<no docstring>
| \n", + " | Volume | \n", + "Centroid x | \n", + "Centroid y | \n", + "Centroid z | \n", + "Sphericity (axes) | \n", + "Image size | \n", + "Total image volume | \n", + "Total object volume (pixels) | \n", + "Filling ratio | \n", + "Number objects | \n", + "
|---|---|---|---|---|---|---|---|---|---|---|
| 0 | \n", + "190.0 | \n", + "5.405263 | \n", + "69.157895 | \n", + "36.210526 | \n", + "0.778113 | \n", + "(124, 86, 94) | \n", + "1002416 | \n", + "33504.0 | \n", + "0.033423 | \n", + "322 | \n", + "
| 1 | \n", + "18.0 | \n", + "5.833333 | \n", + "85.000000 | \n", + "83.944444 | \n", + "0.000007 | \n", + "\n", + " | \n", + " | \n", + " | \n", + " | \n", + " |
| 2 | \n", + "67.0 | \n", + "7.283582 | \n", + "65.492537 | \n", + "92.059701 | \n", + "0.867751 | \n", + "\n", + " | \n", + " | \n", + " | \n", + " | \n", + " |
| 3 | \n", + "108.0 | \n", + "10.324074 | \n", + "84.342593 | \n", + "68.861111 | \n", + "0.672490 | \n", + "\n", + " | \n", + " | \n", + " | \n", + " | \n", + " |
| 4 | \n", + "35.0 | \n", + "9.428571 | \n", + "84.314286 | \n", + "92.600000 | \n", + "0.649649 | \n", + "\n", + " | \n", + " | \n", + " | \n", + " | \n", + " |
| ... | \n", + "... | \n", + "... | \n", + "... | \n", + "... | \n", + "... | \n", + "... | \n", + "... | \n", + "... | \n", + "... | \n", + "... | \n", + "
| 317 | \n", + "11.0 | \n", + "122.363636 | \n", + "14.727273 | \n", + "25.000000 | \n", + "0.951651 | \n", + "\n", + " | \n", + " | \n", + " | \n", + " | \n", + " |
| 318 | \n", + "24.0 | \n", + "122.166667 | \n", + "26.083333 | \n", + "38.083333 | \n", + "0.990075 | \n", + "\n", + " | \n", + " | \n", + " | \n", + " | \n", + " |
| 319 | \n", + "16.0 | \n", + "122.125000 | \n", + "34.125000 | \n", + "36.500000 | \n", + "0.944672 | \n", + "\n", + " | \n", + " | \n", + " | \n", + " | \n", + " |
| 320 | \n", + "13.0 | \n", + "122.076923 | \n", + "43.538462 | \n", + "53.615385 | \n", + "0.939852 | \n", + "\n", + " | \n", + " | \n", + " | \n", + " | \n", + " |
| 321 | \n", + "21.0 | \n", + "122.523810 | \n", + "49.666667 | \n", + "36.238095 | \n", + "0.895437 | \n", + "\n", + " | \n", + " | \n", + " | \n", + " | \n", + " |
322 rows × 10 columns
\n", + "update_plot
def update_plot(z)
<no docstring>
| \n", - " | Volume | \n", - "Centroid x | \n", - "Centroid y | \n", - "Centroid z | \n", - "Sphericity (axes) | \n", - "Image size | \n", - "Total image volume | \n", - "Total object volume (pixels) | \n", - "Filling ratio | \n", - "Number objects | \n", - "
|---|---|---|---|---|---|---|---|---|---|---|
| 0 | \n", - "190.0 | \n", - "5.405263 | \n", - "69.157895 | \n", - "36.210526 | \n", - "0.778113 | \n", - "(124, 86, 94) | \n", - "1002416 | \n", - "33504.0 | \n", - "0.033423 | \n", - "322 | \n", - "
| 1 | \n", - "18.0 | \n", - "5.833333 | \n", - "85.000000 | \n", - "83.944444 | \n", - "0.000007 | \n", - "\n", - " | \n", - " | \n", - " | \n", - " | \n", - " |
| 2 | \n", - "67.0 | \n", - "7.283582 | \n", - "65.492537 | \n", - "92.059701 | \n", - "0.867751 | \n", - "\n", - " | \n", - " | \n", - " | \n", - " | \n", - " |
| 3 | \n", - "108.0 | \n", - "10.324074 | \n", - "84.342593 | \n", - "68.861111 | \n", - "0.672490 | \n", - "\n", - " | \n", - " | \n", - " | \n", - " | \n", - " |
| 4 | \n", - "35.0 | \n", - "9.428571 | \n", - "84.314286 | \n", - "92.600000 | \n", - "0.649649 | \n", - "\n", - " | \n", - " | \n", - " | \n", - " | \n", - " |
| ... | \n", - "... | \n", - "... | \n", - "... | \n", - "... | \n", - "... | \n", - "... | \n", - "... | \n", - "... | \n", - "... | \n", - "... | \n", - "
| 317 | \n", - "11.0 | \n", - "122.363636 | \n", - "14.727273 | \n", - "25.000000 | \n", - "0.951651 | \n", - "\n", - " | \n", - " | \n", - " | \n", - " | \n", - " |
| 318 | \n", - "24.0 | \n", - "122.166667 | \n", - "26.083333 | \n", - "38.083333 | \n", - "0.990075 | \n", - "\n", - " | \n", - " | \n", - " | \n", - " | \n", - " |
| 319 | \n", - "16.0 | \n", - "122.125000 | \n", - "34.125000 | \n", - "36.500000 | \n", - "0.944672 | \n", - "\n", - " | \n", - " | \n", - " | \n", - " | \n", - " |
| 320 | \n", - "13.0 | \n", - "122.076923 | \n", - "43.538462 | \n", - "53.615385 | \n", - "0.939852 | \n", - "\n", - " | \n", - " | \n", - " | \n", - " | \n", - " |
| 321 | \n", - "21.0 | \n", - "122.523810 | \n", - "49.666667 | \n", - "36.238095 | \n", - "0.895437 | \n", - "\n", - " | \n", - " | \n", - " | \n", - " | \n", - " |
322 rows × 10 columns
\n", - "